diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index ab141a74aef8..80d74c2960ac 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -14,6 +14,8 @@ def _encoder_forward(
     end_idx: int,
     hidden_states: torch.Tensor,
     head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+    output_hidden_states: bool = False,
     return_dict: bool = True,
     stage_manager: PipelineStageManager = None,
 ) -> Union[tuple, BaseModelOutput]:
@@ -23,20 +25,14 @@ def _encoder_forward(
         layer_head_mask = head_mask[i] if head_mask is not None else None
 
         if encoder.gradient_checkpointing and encoder.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs, False)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(layer_module),
-                hidden_states,
-                layer_head_mask,
-            )
+            layer_outputs = encoder._gradient_checkpointing_func(
+                layer_module.__call__,
+                hidden_states,
+                layer_head_mask,
+                output_attentions,
+            )
         else:
-            layer_outputs = layer_module(hidden_states, layer_head_mask, False)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
 
         hidden_states = layer_outputs[0]
         if not stage_manager.is_last_stage():
@@ -112,6 +108,8 @@ def pp_forward(
             end_idx=stage_index[1],
             hidden_states=hidden_states,
             head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             stage_manager=stage_manager,
         )
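
Note: the create_custom_forward wrapper around torch.utils.checkpoint.checkpoint that this diff removes is the pattern older transformers releases used; newer releases instead install a ready-to-call _gradient_checkpointing_func on the module when gradient checkpointing is enabled, which is what the new branch relies on. The standalone sketch below is illustration only, not ColossalAI or transformers code: TinyEncoder and its methods are hypothetical, and it assumes _gradient_checkpointing_func amounts to a functools.partial of torch.utils.checkpoint.checkpoint.

# Minimal, self-contained sketch (hypothetical TinyEncoder, not ColossalAI or
# transformers code) of the checkpointing mechanism assumed above.
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyEncoder(nn.Module):
    def __init__(self, hidden: int = 16, num_layers: int = 2):
        super().__init__()
        self.layer = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = None

    def gradient_checkpointing_enable(self):
        # Store a ready-to-call checkpoint function on the module, mirroring how
        # recent transformers releases set it up (use_reentrant=False is assumed).
        self.gradient_checkpointing = True
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer_module in self.layer:
            if self.gradient_checkpointing and self.training:
                # No create_custom_forward wrapper: positional arguments are
                # forwarded directly to the checkpointed callable.
                hidden_states = self._gradient_checkpointing_func(
                    layer_module.__call__, hidden_states
                )
            else:
                hidden_states = layer_module(hidden_states)
        return hidden_states


if __name__ == "__main__":
    enc = TinyEncoder()
    enc.gradient_checkpointing_enable()
    enc.train()
    x = torch.randn(4, 16, requires_grad=True)
    enc(x).sum().backward()
    # Gradients still reach the layer weights through the checkpointed calls.
    print(enc.layer[0].weight.grad is not None)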