-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[shardformer] support pipeline base vit model #4284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6f7734e
Feature/vit support (#4182)
klhhhhh da3fc20
support base vit pipeline
FoolPlayer 0eff55a
support vit downstream model
FoolPlayer 8869189
fix vit shard test
FoolPlayer dff8d56
modify hidden states return type
FoolPlayer 7ca55d5
Merge branch 'feature/pipeline' of https://github.com/hpcaitech/Colos…
FoolPlayer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,337 @@ | ||
| import logging | ||
| from typing import Dict, List, Optional, Set, Tuple, Union | ||
|
|
||
| import torch | ||
| from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder | ||
|
|
||
| from colossalai.pipeline.stage_manager import PipelineStageManager | ||
|
|
||
|
|
||
| def _encoder_forward( | ||
| encoder: ViTEncoder, | ||
| start_idx: int, | ||
| end_idx: int, | ||
| hidden_states: torch.Tensor, | ||
| head_mask: Optional[torch.Tensor] = None, | ||
| return_dict: bool = True, | ||
| stage_manager: PipelineStageManager = None, | ||
| ) -> Union[tuple, BaseModelOutput]: | ||
|
|
||
| for i in range(start_idx, end_idx): | ||
| layer_module = encoder.layer[i] | ||
|
|
||
| layer_head_mask = head_mask[i] if head_mask is not None else None | ||
|
|
||
| if encoder.gradient_checkpointing and encoder.training: | ||
|
|
||
| def create_custom_forward(module): | ||
|
|
||
| def custom_forward(*inputs): | ||
| return module(*inputs, False) | ||
|
|
||
| return custom_forward | ||
|
|
||
| layer_outputs = torch.utils.checkpoint.checkpoint( | ||
| create_custom_forward(layer_module), | ||
| hidden_states, | ||
| layer_head_mask, | ||
| ) | ||
| else: | ||
| layer_outputs = layer_module(hidden_states, layer_head_mask, False) | ||
|
|
||
| hidden_states = layer_outputs[0] | ||
| if not stage_manager.is_last_stage(): | ||
| return hidden_states | ||
| else: | ||
| if not return_dict: | ||
| return tuple(hidden_states) | ||
| return BaseModelOutput( | ||
| last_hidden_state=hidden_states, | ||
| hidden_states=None, | ||
| attentions=None, | ||
| ) | ||
|
|
||
|
|
||
| def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): | ||
|
|
||
| from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling | ||
|
|
||
| def pp_forward( | ||
| self, | ||
| pixel_values: Optional[torch.Tensor] = None, | ||
| bool_masked_pos: Optional[torch.BoolTensor] = None, | ||
| head_mask: Optional[torch.Tensor] = None, | ||
| output_attentions: Optional[bool] = None, | ||
| output_hidden_states: Optional[bool] = None, | ||
| interpolate_pos_encoding: Optional[bool] = None, | ||
| return_dict: Optional[bool] = None, | ||
| hidden_states: Optional[torch.FloatTensor] = None, | ||
| ) -> Union[Tuple, BaseModelOutputWithPooling]: | ||
| r""" | ||
| bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): | ||
| Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). | ||
| """ | ||
|
|
||
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
| output_hidden_states = (output_hidden_states | ||
| if output_hidden_states is not None else self.config.output_hidden_states) | ||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
|
||
| if output_attentions is not None: | ||
| logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.') | ||
| output_attentions = None | ||
| if output_hidden_states is not None: | ||
| logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.') | ||
| output_hidden_states = None | ||
|
FoolPlayer marked this conversation as resolved.
|
||
|
|
||
| # Prepare head mask if needed | ||
| # 1.0 in head_mask indicate we keep the head | ||
| # attention_probs has shape bsz x n_heads x N x N | ||
| # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] | ||
| # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] | ||
| head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) | ||
|
|
||
| if stage_manager.is_first_stage(): | ||
| if pixel_values is None: | ||
| raise ValueError("You have to specify pixel_values") | ||
|
|
||
| # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) | ||
| expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype | ||
| if pixel_values.dtype != expected_dtype: | ||
| pixel_values = pixel_values.to(expected_dtype) | ||
|
|
||
| embedding_output = self.embeddings(pixel_values, | ||
| bool_masked_pos=bool_masked_pos, | ||
| interpolate_pos_encoding=interpolate_pos_encoding) | ||
| else: | ||
| assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" | ||
|
|
||
| # Go through encoder | ||
| if not stage_manager.is_last_stage(): | ||
| hidden_states = _encoder_forward( | ||
| encoder=self.encoder, | ||
| start_idx=stage_index[0], | ||
| end_idx=stage_index[1], | ||
| hidden_states=embedding_output, | ||
| head_mask=head_mask, | ||
| return_dict=return_dict, | ||
| stage_manager=stage_manager, | ||
| ) | ||
| return {'hidden_states': hidden_states} | ||
| else: | ||
| encoder_outputs = _encoder_forward( | ||
| encoder=self.encoder, | ||
| start_idx=stage_index[0], | ||
| end_idx=stage_index[1], | ||
| hidden_states=hidden_states, | ||
| head_mask=head_mask, | ||
| return_dict=return_dict, | ||
| stage_manager=stage_manager, | ||
| ) | ||
|
|
||
| # Go through rest layers | ||
| sequence_output = encoder_outputs[0] | ||
| sequence_output = self.layernorm(sequence_output) | ||
| pooled_output = self.pooler(sequence_output) if self.pooler is not None else None | ||
|
|
||
| if not return_dict: | ||
| head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) | ||
| return head_outputs + encoder_outputs[1:] | ||
|
|
||
| return BaseModelOutputWithPooling( | ||
| last_hidden_state=sequence_output, | ||
| pooler_output=pooled_output, | ||
| hidden_states=encoder_outputs.hidden_states, | ||
| attentions=encoder_outputs.attentions, | ||
| ) | ||
|
|
||
| return pp_forward | ||
|
|
||
|
|
||
| def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): | ||
|
|
||
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||
| from transformers.models.vit.modeling_vit import ImageClassifierOutput | ||
|
|
||
| def pp_forward( | ||
| self, | ||
| pixel_values: Optional[torch.Tensor] = None, | ||
| head_mask: Optional[torch.Tensor] = None, | ||
| labels: Optional[torch.Tensor] = None, | ||
| output_attentions: Optional[bool] = None, | ||
| output_hidden_states: Optional[bool] = None, | ||
| interpolate_pos_encoding: Optional[bool] = None, | ||
| return_dict: Optional[bool] = None, | ||
| hidden_states: Optional[torch.FloatTensor] = None, | ||
| ) -> Union[tuple, ImageClassifierOutput]: | ||
| r""" | ||
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | ||
| Labels for computing the image classification/regression loss. Indices should be in `[0, ..., | ||
| config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If | ||
| `config.num_labels > 1` a classification loss is computed (Cross-Entropy). | ||
| """ | ||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
|
||
| if not stage_manager.is_first_stage(): | ||
| assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" | ||
|
|
||
| outputs = self.vit( | ||
| pixel_values, | ||
| head_mask=head_mask, | ||
| output_attentions=output_attentions, | ||
| output_hidden_states=output_hidden_states, | ||
| interpolate_pos_encoding=interpolate_pos_encoding, | ||
| return_dict=return_dict, | ||
| hidden_states=hidden_states, | ||
| ) | ||
|
|
||
| # not last stage, return hidden_states | ||
| if not stage_manager.is_last_stage(): | ||
| return outputs | ||
| else: | ||
| sequence_output = outputs[0] | ||
|
|
||
| # last stage | ||
| logits = self.classifier(sequence_output[:, 0, :]) | ||
|
|
||
| loss = None | ||
| if labels is not None: | ||
| # move labels to correct device to enable model parallelism | ||
| labels = labels.to(logits.device) | ||
| if self.config.problem_type is None: | ||
| if self.num_labels == 1: | ||
| self.config.problem_type = "regression" | ||
| elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): | ||
| self.config.problem_type = "single_label_classification" | ||
| else: | ||
| self.config.problem_type = "multi_label_classification" | ||
|
|
||
| if self.config.problem_type == "regression": | ||
| loss_fct = MSELoss() | ||
| if self.num_labels == 1: | ||
| loss = loss_fct(logits.squeeze(), labels.squeeze()) | ||
| else: | ||
| loss = loss_fct(logits, labels) | ||
| elif self.config.problem_type == "single_label_classification": | ||
| loss_fct = CrossEntropyLoss() | ||
| loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | ||
| elif self.config.problem_type == "multi_label_classification": | ||
| loss_fct = BCEWithLogitsLoss() | ||
| loss = loss_fct(logits, labels) | ||
|
|
||
| if not return_dict: | ||
| output = (logits,) + outputs[1:] | ||
| return ((loss,) + output) if loss is not None else output | ||
|
|
||
| return ImageClassifierOutput( | ||
| loss=loss, | ||
| logits=logits, | ||
| hidden_states=outputs.hidden_states, | ||
| attentions=outputs.attentions, | ||
| ) | ||
|
|
||
| return pp_forward | ||
|
|
||
|
|
||
| def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): | ||
|
|
||
| import math | ||
|
|
||
| import torch.nn as nn | ||
| from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput | ||
|
|
||
| def pp_forward( | ||
| self, | ||
| pixel_values: Optional[torch.Tensor] = None, | ||
| bool_masked_pos: Optional[torch.BoolTensor] = None, | ||
| head_mask: Optional[torch.Tensor] = None, | ||
| output_attentions: Optional[bool] = None, | ||
| output_hidden_states: Optional[bool] = None, | ||
| interpolate_pos_encoding: Optional[bool] = None, | ||
| return_dict: Optional[bool] = None, | ||
| hidden_states: Optional[torch.FloatTensor] = None, | ||
| ) -> Union[tuple, ImageClassifierOutput]: | ||
| r""" | ||
| bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): | ||
| Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). | ||
|
|
||
| Returns: | ||
|
|
||
| Examples: | ||
| ```python | ||
| >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling | ||
| >>> import torch | ||
| >>> from PIL import Image | ||
| >>> import requests | ||
|
|
||
| >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" | ||
| >>> image = Image.open(requests.get(url, stream=True).raw) | ||
|
|
||
| >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") | ||
| >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k") | ||
|
|
||
| >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 | ||
| >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values | ||
| >>> # create random boolean mask of shape (batch_size, num_patches) | ||
| >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() | ||
|
|
||
| >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) | ||
| >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction | ||
| >>> list(reconstructed_pixel_values.shape) | ||
| [1, 3, 224, 224] | ||
| ```""" | ||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
|
||
| if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): | ||
| raise ValueError( | ||
| "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " | ||
| "the reconstructed image has the same dimensions as the input." | ||
| f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.") | ||
|
|
||
| if not stage_manager.is_first_stage(): | ||
| assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" | ||
|
|
||
| outputs = self.vit(pixel_values, | ||
| bool_masked_pos=bool_masked_pos, | ||
| head_mask=head_mask, | ||
| output_attentions=output_attentions, | ||
| output_hidden_states=output_hidden_states, | ||
| interpolate_pos_encoding=interpolate_pos_encoding, | ||
| return_dict=return_dict, | ||
| hidden_states=hidden_states) | ||
| if not stage_manager.is_last_stage(): | ||
| return outputs | ||
| else: | ||
| sequence_output = outputs[0] | ||
|
|
||
| # Reshape to (batch_size, num_channels, height, width) | ||
| sequence_output = sequence_output[:, 1:] | ||
| batch_size, sequence_length, num_channels = sequence_output.shape | ||
| height = width = math.floor(sequence_length**0.5) | ||
| sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) | ||
|
|
||
| # Reconstruct pixel values | ||
| reconstructed_pixel_values = self.decoder(sequence_output) | ||
|
|
||
| masked_im_loss = None | ||
| if bool_masked_pos is not None: | ||
| size = self.config.image_size // self.config.patch_size | ||
| bool_masked_pos = bool_masked_pos.reshape(-1, size, size) | ||
| mask = (bool_masked_pos.repeat_interleave(self.config.patch_size, | ||
| 1).repeat_interleave(self.config.patch_size, | ||
| 2).unsqueeze(1).contiguous()) | ||
| reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") | ||
| masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels | ||
|
|
||
| if not return_dict: | ||
| output = (reconstructed_pixel_values,) + outputs[1:] | ||
| return ((masked_im_loss,) + output) if masked_im_loss is not None else output | ||
|
|
||
| return MaskedImageModelingOutput( | ||
| loss=masked_im_loss, | ||
| reconstruction=reconstructed_pixel_values, | ||
| hidden_states=outputs.hidden_states, | ||
| attentions=outputs.attentions, | ||
| ) | ||
|
|
||
| return pp_forward | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.