From 6a528b5bd46d3a015a3877eb806fa8ae8d1d41da Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Thu, 28 Mar 2024 11:33:05 +0800
Subject: [PATCH 1/2] update vit model

---
 colossalai/shardformer/modeling/vit.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index ab141a74aef8..c74813684272 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -14,29 +14,27 @@ def _encoder_forward(
     end_idx: int,
     hidden_states: torch.Tensor,
     head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+    output_hidden_states: bool = False,
     return_dict: bool = True,
     stage_manager: PipelineStageManager = None,
 ) -> Union[tuple, BaseModelOutput]:
     for i in range(start_idx, end_idx):
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
         layer_module = encoder.layer[i]
         layer_head_mask = head_mask[i] if head_mask is not None else None
 
         if encoder.gradient_checkpointing and encoder.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs, False)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(layer_module),
-                hidden_states,
-                layer_head_mask,
-            )
+            layer_outputs = encoder._gradient_checkpointing_func(
+                layer_module.__call__,
+                hidden_states,
+                layer_head_mask,
+                output_attentions,
+            )
         else:
-            layer_outputs = layer_module(hidden_states, layer_head_mask, False)
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
 
         hidden_states = layer_outputs[0]
 
         if not stage_manager.is_last_stage():
@@ -112,6 +110,8 @@ def pp_forward(
             end_idx=stage_index[1],
             hidden_states=hidden_states,
             head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             stage_manager=stage_manager,
         )
From 11265014490db7a6850bf1923c8733f28b2f5043 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Wed, 3 Apr 2024 13:15:06 +0800
Subject: [PATCH 2/2] remove the output_hidden_states

---
 colossalai/shardformer/modeling/vit.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index c74813684272..80d74c2960ac 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -20,8 +20,6 @@ def _encoder_forward(
     stage_manager: PipelineStageManager = None,
 ) -> Union[tuple, BaseModelOutput]:
     for i in range(start_idx, end_idx):
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
         layer_module = encoder.layer[i]
         layer_head_mask = head_mask[i] if head_mask is not None else None
 