From b32aab63c444aa3fe91457113b01cbbac581f156 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Wed, 27 Mar 2024 18:54:22 +0800
Subject: [PATCH] update t5 model

---
 colossalai/shardformer/modeling/t5.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index 9c5ce3fb65c9..94f4fce74501 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -118,15 +118,12 @@ def t5_stack_forward(
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
-        if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
-
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
+
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
@@ -138,7 +135,9 @@ def t5_stack_forward(
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+                )
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
@@ -162,15 +161,8 @@ def t5_stack_forward(
                 torch.cuda.set_device(hidden_states.device)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
-                layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.forward,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -180,6 +172,8 @@ def custom_forward(*inputs):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
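
Note (not part of the patch): the main change above replaces the old `create_custom_forward` closure passed to `torch.utils.checkpoint.checkpoint` with a call to the model's stored `self._gradient_checkpointing_func`, passing `use_cache` and `output_attentions` as trailing positional arguments instead of capturing them in a closure. The sketch below illustrates that call pattern in isolation, assuming only plain PyTorch; `ToyBlock` and `checkpointing_func` are illustrative stand-ins, not the real T5 layer or the attribute Transformers sets via `gradient_checkpointing_enable()`.

```python
# Minimal sketch of the gradient-checkpointing pattern the diff switches to.
# `ToyBlock` and `checkpointing_func` are hypothetical names for illustration.
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Stand-in for a transformer layer that returns a tuple of outputs."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        hidden_states = torch.relu(self.linear(hidden_states))
        return (hidden_states, hidden_states if output_attentions else None)


block = ToyBlock()
hidden = torch.randn(2, 4, 8, requires_grad=True)

# Old style (removed by the patch): wrap the module in a closure that appends
# the extra flags, then hand the closure to `checkpoint`.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return tuple(module(*inputs, output_attentions=False))

    return custom_forward

old_outputs = checkpoint(create_custom_forward(block), hidden, None, use_reentrant=False)

# New style (added by the patch): call a stored checkpointing function with
# `module.forward` and pass every argument, including the flags, positionally.
checkpointing_func = partial(checkpoint, use_reentrant=False)
new_outputs = checkpointing_func(block.forward, hidden, None, False)

# Both call paths produce the same layer output.
assert torch.allclose(old_outputs[0], new_outputs[0])
```

Keeping the flags as ordinary positional arguments avoids the nested closure and lets a single checkpointing function be swapped in for the whole model, which is the convention recent Hugging Face Transformers versions follow.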