diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index 1990d7df3279..187e35e40dd4 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -123,11 +123,9 @@ def gptj_model_forward(
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
         # position id to be assigned not just for the first stage for attn input
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, seq_length)
-        else:
+        if position_ids is None:
             position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
         if stage_manager.is_first_stage():
             if inputs_embeds is None:
                 inputs_embeds = self.wte(input_ids)
@@ -172,21 +170,15 @@ def gptj_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     position_ids,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -603,7 +595,9 @@ def forward(
             value = torch.cat((past_value, value), dim=1)
 
         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
 
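
For reference, the first hunk's default-`position_ids` path is behavior-preserving: `arange(seq_length).unsqueeze(0)` already has shape `(1, seq_length)`, so the old `.view(-1, input_shape[-1])` was a no-op (caller-supplied `position_ids` are now simply passed through unreshaped). A minimal standalone check with a toy `seq_length`:

```python
import torch

seq_length = 8  # toy value; in the model this is the current input's length
position_ids = torch.arange(0, seq_length, dtype=torch.long)

old = position_ids.unsqueeze(0).view(-1, seq_length)  # pre-patch default path
new = position_ids.unsqueeze(0)                       # post-patch default path
assert torch.equal(old, new) and new.shape == (1, seq_length)
```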
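The second hunk follows the transformers >= 4.36 convention: `gradient_checkpointing_enable()` installs `self._gradient_checkpointing_func` (a `functools.partial` over `torch.utils.checkpoint.checkpoint`), so the per-layer closure wrapper is no longer needed and every argument, including `use_cache` and `output_attentions`, is passed positionally; `block.__call__` is used rather than `block.forward` so module hooks still fire under checkpointing. Below is a minimal sketch of the two styles; `ToyBlock` and the `use_reentrant=False` choice are illustrative assumptions (the real kwargs come from `gradient_checkpointing_kwargs`):

```python
import torch
import torch.nn as nn
from functools import partial


class ToyBlock(nn.Module):
    """Stand-in for a transformer block taking positional args."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, layer_past=None, attention_mask=None):
        return (self.proj(hidden_states),)


block = ToyBlock()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)

# Old pattern: wrap the module in a closure so extra flags can be smuggled in
# (the real code appended use_cache and output_attentions here), then call
# torch.utils.checkpoint.checkpoint on the closure.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward

out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden_states, None, None, use_reentrant=False
)

# New pattern: transformers installs roughly this partial on the model, and the
# layer plus all of its positional arguments are handed over directly.
_gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
out_new = _gradient_checkpointing_func(block.__call__, hidden_states, None, None)

assert torch.allclose(out_old[0], out_new[0])
```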
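The third hunk mirrors upstream GPT-J (the inline comment is copied verbatim from `transformers`): the key is kept in float32 through the rotary computation, following mesh-transformer-jax, so under fp16/bf16 it must be cast back to the activation dtype before being stored in `present`, otherwise cached keys would mix dtypes across decoding steps. A hedged sketch of the effect, with toy shapes and no attention module:

```python
import torch

hidden_states = torch.randn(1, 4, 16, dtype=torch.float16)
key = torch.randn(1, 4, 16, dtype=torch.float32)   # float32 after the RoPE step
value = torch.randn(1, 4, 16, dtype=torch.float16)

present = (key.to(hidden_states.dtype), value)     # cache in the activation dtype
assert present[0].dtype == present[1].dtype == torch.float16
```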