From 6f7734e80c2e5b3361ac340eda83a79894c4741c Mon Sep 17 00:00:00 2001 From: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Date: Fri, 7 Jul 2023 14:06:46 +0800 Subject: [PATCH 1/5] Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout --- .../shardformer/policies/auto_policy.py | 8 + colossalai/shardformer/policies/vit.py | 170 +++++++++--------- tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/vit.py | 60 +++++++ .../test_model/test_shard_vit.py | 62 +++++-- 5 files changed, 201 insertions(+), 100 deletions(-) create mode 100644 tests/kit/model_zoo/transformers/vit.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index ccdb33b2efe5..e27c2cd9f0ef 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -73,6 +73,14 @@ class PolicyLocation: "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), + # ViT + "transformers.models.vit.modeling_vit.ViTModel": + PolicyLocation(file_name="vit", class_name="ViTPolicy"), + "transformers.models.vit.modeling_vit.ViTForImageClassification": + PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"), + "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": + PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"), + # OPT "transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 3f6bbd10607a..c5fa30c1901b 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -2,11 +2,11 @@ import torch.nn as nn -from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ViTPolicy'] +__all__ = ['ViTPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] class ViTPolicy(Policy): @@ -15,96 +15,104 @@ def config_sanity_check(self): pass def preprocess(self): - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer - base_policy = { - ViTEmbeddings: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForReplicatedInput, - ) - ]), - ViTLayer: - ModulePolicyDescription(attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - }, + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[ViTEmbeddings] = 
ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=DropoutForParallelInput, - ), - ]), - } - - # optimization configuration - if self.shard_config.enable_fused_normalization: - base_policy[ViTAttention].sub_module_replacement.extend([ - SubModuleReplacementDescription( - suffix="layernorm_before", - target_module=FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layernorm_after", - target_module=FusedLayerNorm, + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ]) + + policy[ViTLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size//self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=DropoutForReplicatedInput, + ), + ] ) - ]) - base_policy[ViTModel].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="layernorm", - target_module=FusedLayerNorm, - )) - - return base_policy + return policy + + def new_model_class(self): return None def postprocess(self): return self.model + +class ViTForImageClassificationPolicy(ViTPolicy): + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForImageClassification + + policy = super().module_policy() + if self.shard_config.enable_tensor_parallelism: + new_item = { + ViTForImageClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription(suffix="classifier", + target_module=Linear1D_Col, + 
kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy + +class ViTForMaskedImageModelingPolicy(ViTPolicy): + + def module_policy(self): + policy = super().module_policy() + return policy + + + + diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 4aa01abe13ee..a298767d12e7 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -5,3 +5,4 @@ from .llama import * from .opt import * from .t5 import * +from .vit import * diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py new file mode 100644 index 000000000000..1c86c7ebc742 --- /dev/null +++ b/tests/kit/model_zoo/transformers/vit.py @@ -0,0 +1,60 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence VIT +# =============================== + +config = transformers.ViTConfig(num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4) + +# define data gen function +def data_gen(): + pixel_values = torch.randn(1, 3, 224, 224) + return dict(pixel_values = pixel_values) + +def data_gen_for_masked_image_modeling(): + data = data_gen() + num_patches = (config.image_size // config.patch_size) ** 2 + bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + data['bool_masked_pos'] = bool_masked_pos + return data + +# define output transform function +output_transform_fn = lambda x: x + +# function to get the loss +loss_fn_for_vit_model = lambda x : x.pooler_output.mean() +loss_fn_for_image_classification = lambda x : x.logits.mean() +loss_fn_for_masked_image_modeling = lambda x : x.loss + +# register the following models +# transformers.ViTModel, +# transformers.ViTForMaskedImageModeling, +# transformers.ViTForImageClassification, +model_zoo.register(name = 'transformers_vit', + model_fn = lambda : transformers.ViTModel(config), + data_gen_fn = data_gen, + output_transform_fn = output_transform_fn, + loss_fn = loss_fn_for_vit_model, + model_attribute = ModelAttribute(has_control_flow=True)) + +model_zoo.register(name = 'transformers_vit_for_masked_image_modeling', + model_fn = lambda : transformers.ViTForMaskedImageModeling(config), + data_gen_fn = data_gen_for_masked_image_modeling, + output_transform_fn = output_transform_fn, + loss_fn = loss_fn_for_masked_image_modeling, + model_attribute = ModelAttribute(has_control_flow=True)) + +model_zoo.register(name = 'transformers_vit_for_image_classification', + model_fn = lambda : transformers.ViTForImageClassification(config), + data_gen_fn = data_gen, + output_transform_fn = output_transform_fn, + loss_fn = loss_fn_for_image_classification, + model_attribute = ModelAttribute(has_control_flow=True)) + + diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index af1605b6b659..a96fd02ae746 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -1,9 +1,18 @@ +import os + import pytest import torch import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + 
clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -12,44 +21,59 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output) - + assert_hf_output_close(org_output, shard_output, atol=1e-4, rtol=1e-4) # do backward org_loss.backward() shard_loss.backward() - # check grad - org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad - shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad - - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + # unwrap model + if org_model.__class__.__name__ == 'ViTModel': + vit_model = org_model + shard_vit_model = sharded_model + else: + vit_model = org_model.vit + shard_vit_model = sharded_model.vit -def check_vit(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + # check attention grad + org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad + shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad + shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_vit_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit_test() + + @pytest.mark.dist @pytest.mark.skip @rerun_if_address_is_in_use() @clear_cache_before_run() def test_vit(): - spawn(check_vit, 4) + spawn(check_vit, 2) if __name__ == "__main__": From da3fc204ba5271329a3270724104f11932b20e3d Mon Sep 17 00:00:00 2001 From: FoolPlayer 
<498107402@qq.com> Date: Wed, 19 Jul 2023 16:17:21 +0800 Subject: [PATCH 2/5] support base vit pipeline --- colossalai/shardformer/modeling/vit.py | 136 ++++++++++++++++++ .../shardformer/policies/auto_policy.py | 2 +- colossalai/shardformer/policies/vit.py | 78 ++++++++-- tests/kit/model_zoo/transformers/vit.py | 64 +++++---- .../test_model/test_shard_vit.py | 3 +- .../test_model/test_shard_vit_pipeline.py | 73 ++++++++++ 6 files changed, 313 insertions(+), 43 deletions(-) create mode 100644 colossalai/shardformer/modeling/vit.py create mode 100644 tests/test_shardformer/test_model/test_shard_vit_pipeline.py diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py new file mode 100644 index 000000000000..2d9f8a0967bb --- /dev/null +++ b/colossalai/shardformer/modeling/vit.py @@ -0,0 +1,136 @@ +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from transformers.models.vit.modeling_vit import BaseModelOutput, BaseModelOutputWithPooling, ViTEncoder + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +def forward_fn(stage_manager: PipelineStageManager, stage_index: List[int]): + + def _encoder_forward( + encoder: ViTEncoder, + start_idx: int, + end_idx: int, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + stage_manager: PipelineStageManager = None, + ) -> Union[tuple, BaseModelOutput]: + + for i in range(start_idx, end_idx): + layer_module = encoder.layer[i] + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if encoder.gradient_checkpointing and encoder.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, False) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, False) + + hidden_states = layer_outputs[0] + if stage_manager.is_first_stage(): + return hidden_states + else: + if not return_dict: + return tuple(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=None, + attentions=None, + ) + + def vit_pipeline_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if stage_manager.is_first_stage(): + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings(pixel_values, + bool_masked_pos=bool_masked_pos, + interpolate_pos_encoding=interpolate_pos_encoding) + else: + assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + # Go through encoder + if stage_manager.is_first_stage(): + hidden_states = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=embedding_output, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) + return hidden_states + else: + encoder_outputs = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=hidden_states, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) + + # Go through rest layers + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + return vit_pipeline_forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index e27c2cd9f0ef..e8d1f34510a7 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -75,7 +75,7 @@ class PolicyLocation: # ViT "transformers.models.vit.modeling_vit.ViTModel": - PolicyLocation(file_name="vit", class_name="ViTPolicy"), + PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"), "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index c5fa30c1901b..85b2547d3ac8 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,12 +1,15 @@ -from typing import Dict, Union +from functools import partial +from typing import Callable, Dict, List, Union 
import torch.nn as nn +from torch import Tensor from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row +from ..modeling.vit import forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ViTPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] +__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] class ViTPolicy(Policy): @@ -89,6 +92,66 @@ def new_model_class(self): def postprocess(self): return self.model + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + if self.model.__class__.__name__ == 'ViTModel': + module = self.model + else: + module = self.model.vit + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'ViTModel': + module = self.model + else: + module = self.model.vit + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + +# ViTModel +class ViTModelPolicy(ViTPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=forward_fn, policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return super().get_shared_params() + + +# ViTForImageClassification class ViTForImageClassificationPolicy(ViTPolicy): def module_policy(self): @@ -107,12 +170,9 @@ def module_policy(self): policy.update(new_item) return policy -class ViTForMaskedImageModelingPolicy(ViTPolicy): - - def module_policy(self): - policy = super().module_policy() - return policy - - +# ViTForMaskedImageModeling +class ViTForMaskedImageModelingPolicy(ViTPolicy): + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index 1c86c7ebc742..ac133b5301a6 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -7,54 +7,56 @@ # Register single-sentence VIT # =============================== -config = transformers.ViTConfig(num_hidden_layers=4, - hidden_size=128, - 
intermediate_size=256, - num_attention_heads=4) +config = transformers.ViTConfig( + num_hidden_layers=4, + # hidden_size=128, + # intermediate_size=256, + num_attention_heads=4) + # define data gen function def data_gen(): pixel_values = torch.randn(1, 3, 224, 224) - return dict(pixel_values = pixel_values) + return dict(pixel_values=pixel_values) + def data_gen_for_masked_image_modeling(): data = data_gen() - num_patches = (config.image_size // config.patch_size) ** 2 + num_patches = (config.image_size // config.patch_size)**2 bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() data['bool_masked_pos'] = bool_masked_pos return data + # define output transform function output_transform_fn = lambda x: x # function to get the loss -loss_fn_for_vit_model = lambda x : x.pooler_output.mean() -loss_fn_for_image_classification = lambda x : x.logits.mean() -loss_fn_for_masked_image_modeling = lambda x : x.loss +loss_fn_for_vit_model = lambda x: x.pooler_output.mean() +loss_fn_for_image_classification = lambda x: x.logits.mean() +loss_fn_for_masked_image_modeling = lambda x: x.loss # register the following models # transformers.ViTModel, # transformers.ViTForMaskedImageModeling, # transformers.ViTForImageClassification, -model_zoo.register(name = 'transformers_vit', - model_fn = lambda : transformers.ViTModel(config), - data_gen_fn = data_gen, - output_transform_fn = output_transform_fn, - loss_fn = loss_fn_for_vit_model, - model_attribute = ModelAttribute(has_control_flow=True)) - -model_zoo.register(name = 'transformers_vit_for_masked_image_modeling', - model_fn = lambda : transformers.ViTForMaskedImageModeling(config), - data_gen_fn = data_gen_for_masked_image_modeling, - output_transform_fn = output_transform_fn, - loss_fn = loss_fn_for_masked_image_modeling, - model_attribute = ModelAttribute(has_control_flow=True)) - -model_zoo.register(name = 'transformers_vit_for_image_classification', - model_fn = lambda : transformers.ViTForImageClassification(config), - data_gen_fn = data_gen, - output_transform_fn = output_transform_fn, - loss_fn = loss_fn_for_image_classification, - model_attribute = ModelAttribute(has_control_flow=True)) - - +model_zoo.register(name='transformers_vit', + model_fn=lambda: transformers.ViTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_vit_model, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_vit_for_masked_image_modeling', + model_fn=lambda: transformers.ViTForMaskedImageModeling(config), + data_gen_fn=data_gen_for_masked_image_modeling, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_masked_image_modeling, + model_attribute=ModelAttribute(has_control_flow=True)) + +model_zoo.register(name='transformers_vit_for_image_classification', + model_fn=lambda: transformers.ViTForImageClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_image_classification, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index a96fd02ae746..13cde0e611ff 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -50,7 +50,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), 
f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" - + @parameterize('enable_fused_normalization', [True, False]) @parameterize('enable_tensor_parallelism', [True, False]) @@ -69,7 +69,6 @@ def check_vit(rank, world_size, port): @pytest.mark.dist -@pytest.mark.skip @rerun_if_address_is_in_use() @clear_cache_before_run() def test_vit(): diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py new file mode 100644 index 000000000000..1f5cc33a95b3 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py @@ -0,0 +1,73 @@ +import pytest +import torch + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_pipeline_model + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # TODO: add tests for forward/backward later + pass + + +@parameterize('enable_tensor_parallelism', [False]) +@parameterize('enable_fused_normalization', [False]) +@parameterize('use_lazy_init', [False]) +#TODO: merge this into test_shard_vit +def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): + print(name) + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + pixel_values = inputs['pixel_values'] + batch_size = len(pixel_values) + hidden_size = 768 + hidden_state_shape = (batch_size, 197, hidden_size) + + if not stage_manager.is_first_stage(): + # change inputs if not the first stage + hidden_states = torch.randn(*hidden_state_shape).cuda() + inputs['pixel_values'] = None + inputs['hidden_states'] = hidden_states + + _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + sharded_model.train() + + output = sharded_model(**inputs) + if stage_manager.is_last_stage(): + if name != 'transformers_vit': + assert output.loss is not None + else: + assert output['hidden_states'].shape == hidden_state_shape + + torch.cuda.empty_cache() + + +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit(): + spawn(check_vit, 4) + + +if __name__ == "__main__": + test_vit() From 0eff55a3b9b39a59c27f0baa97d492e507eaecf1 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 20 Jul 2023 10:44:23 +0800 Subject: [PATCH 3/5] support vit downstream model --- colossalai/shardformer/modeling/vit.py | 281 +++++++++++++++--- colossalai/shardformer/policies/vit.py | 72 ++++- tests/kit/model_zoo/transformers/vit.py | 8 +- .../test_model/test_shard_vit_pipeline.py | 6 +- 4 files changed, 313 insertions(+), 54 deletions(-) diff --git 
a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 2d9f8a0967bb..48c7c2a02c79 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,58 +1,62 @@ +import logging from typing import Dict, List, Optional, Set, Tuple, Union import torch -from transformers.models.vit.modeling_vit import BaseModelOutput, BaseModelOutputWithPooling, ViTEncoder +from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder from colossalai.pipeline.stage_manager import PipelineStageManager -def forward_fn(stage_manager: PipelineStageManager, stage_index: List[int]): +def _encoder_forward( + encoder: ViTEncoder, + start_idx: int, + end_idx: int, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + stage_manager: PipelineStageManager = None, +) -> Union[tuple, BaseModelOutput]: - def _encoder_forward( - encoder: ViTEncoder, - start_idx: int, - end_idx: int, - hidden_states: torch.Tensor, - head_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - stage_manager: PipelineStageManager = None, - ) -> Union[tuple, BaseModelOutput]: + for i in range(start_idx, end_idx): + layer_module = encoder.layer[i] - for i in range(start_idx, end_idx): - layer_module = encoder.layer[i] + layer_head_mask = head_mask[i] if head_mask is not None else None - layer_head_mask = head_mask[i] if head_mask is not None else None + if encoder.gradient_checkpointing and encoder.training: - if encoder.gradient_checkpointing and encoder.training: + def create_custom_forward(module): - def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, False) - def custom_forward(*inputs): - return module(*inputs, False) + return custom_forward - return custom_forward + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, False) + + hidden_states = layer_outputs[0] + if not stage_manager.is_last_stage(): + return hidden_states + else: + if not return_dict: + return tuple(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=None, + attentions=None, + ) - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - layer_head_mask, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, False) - hidden_states = layer_outputs[0] - if stage_manager.is_first_stage(): - return hidden_states - else: - if not return_dict: - return tuple(hidden_states) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=None, - attentions=None, - ) +def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): - def vit_pipeline_forward( + from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling + + def pp_forward( self, pixel_values: Optional[torch.Tensor] = None, bool_masked_pos: Optional[torch.BoolTensor] = None, @@ -67,11 +71,19 @@ def vit_pipeline_forward( bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if output_attentions is not None: + logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.') + output_attentions = None + if output_hidden_states is not None: + logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.') + output_hidden_states = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -95,7 +107,7 @@ def vit_pipeline_forward( assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" # Go through encoder - if stage_manager.is_first_stage(): + if not stage_manager.is_last_stage(): hidden_states = _encoder_forward( encoder=self.encoder, start_idx=stage_index[0], @@ -133,4 +145,193 @@ def vit_pipeline_forward( attentions=encoder_outputs.attentions, ) - return vit_pipeline_forward + return pp_forward + + +def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + + from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + from transformers.models.vit.modeling_vit import ImageClassifierOutput + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if not stage_manager.is_first_stage(): + assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states, + ) + + # not last stage, return hidden_states + if not stage_manager.is_last_stage(): + return outputs + else: + sequence_output = outputs[0] + + # last stage + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return pp_forward + + +def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): + + import math + + import torch.nn as nn + from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput + + def pp_forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + hidden_states: Optional[torch.FloatTensor] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
+ + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") + >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 224, 224] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input." + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.") + + if not stage_manager.is_first_stage(): + assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit(pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states) + if not stage_manager.is_last_stage(): + return outputs + else: + sequence_output = outputs[0] + + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output[:, 1:] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = (bool_masked_pos.repeat_interleave(self.config.patch_size, + 1).repeat_interleave(self.config.patch_size, + 2).unsqueeze(1).contiguous()) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return pp_forward diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 85b2547d3ac8..d0f2d47f84fa 100644 --- a/colossalai/shardformer/policies/vit.py 
+++ b/colossalai/shardformer/policies/vit.py @@ -6,7 +6,11 @@ from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row -from ..modeling.vit import forward_fn +from ..modeling.vit import ( + ViTForImageClassification_pipeline_forward, + ViTForMaskedImageModeling_pipeline_forward, + ViTModel_pipeline_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] @@ -108,9 +112,6 @@ def get_held_layers(self) -> List[nn.Module]: held_layers.append(module.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.layernorm) - held_layers.append(module.pooler) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): @@ -141,21 +142,27 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=forward_fn, policy=policy) + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) return policy def get_held_layers(self) -> List[nn.Module]: - return super().get_held_layers() + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" - def get_shared_params(self) -> List[Dict[int, Tensor]]: - return super().get_shared_params() + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) + + return held_layers # ViTForImageClassification class ViTForImageClassificationPolicy(ViTPolicy): - def module_policy(self): - from transformers.models.vit.modeling_vit import ViTForImageClassification + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: @@ -168,11 +175,56 @@ def module_policy(self): ]) } policy.update(new_item) + + from transformers.models.vit.modeling_vit import ViTForImageClassification + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward(model_cls=ViTForImageClassification, + pipeline_forward=ViTForImageClassification_pipeline_forward, + policy=policy) + return policy + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) + + return held_layers + # ViTForMaskedImageModeling class ViTForMaskedImageModelingPolicy(ViTPolicy): def __init__(self) -> None: super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + 
self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling, + pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) + + return held_layers diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index ac133b5301a6..93a8d6c615d7 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -20,6 +20,12 @@ def data_gen(): return dict(pixel_values=pixel_values) +def data_gen_for_image_classification(): + data = data_gen() + data['labels'] = torch.tensor([0]) + return data + + def data_gen_for_masked_image_modeling(): data = data_gen() num_patches = (config.image_size // config.patch_size)**2 @@ -56,7 +62,7 @@ def data_gen_for_masked_image_modeling(): model_zoo.register(name='transformers_vit_for_image_classification', model_fn=lambda: transformers.ViTForImageClassification(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_image_classification, output_transform_fn=output_transform_fn, loss_fn=loss_fn_for_image_classification, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py index 1f5cc33a95b3..5c3d77a62935 100644 --- a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py @@ -28,7 +28,7 @@ def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - print(name) + inputs = data_gen_fn() inputs = {k: v.cuda() for k, v in inputs.items()} pixel_values = inputs['pixel_values'] @@ -39,7 +39,7 @@ def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy if not stage_manager.is_first_stage(): # change inputs if not the first stage hidden_states = torch.randn(*hidden_state_shape).cuda() - inputs['pixel_values'] = None + # inputs['pixel_values'] = None inputs['hidden_states'] = hidden_states _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, @@ -51,7 +51,7 @@ def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy if name != 'transformers_vit': assert output.loss is not None else: - assert output['hidden_states'].shape == hidden_state_shape + assert output.shape == hidden_state_shape, f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}' torch.cuda.empty_cache() From 886918949e5efb3535832ed92ed57bd04fc6559b Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 20 Jul 2023 12:01:31 +0800 Subject: [PATCH 4/5] fix vit shard test --- tests/test_shardformer/test_model/test_shard_vit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 
13cde0e611ff..2b02c83e0d27 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -21,7 +21,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check forward org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, atol=1e-4, rtol=1e-4) + assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3) # do backward org_loss.backward() shard_loss.backward() From dff8d56a9310d51e7bcbca57f37ea5eaae749fa1 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 24 Jul 2023 11:11:31 +0800 Subject: [PATCH 5/5] modify hidden states return type --- colossalai/shardformer/modeling/vit.py | 2 +- colossalai/shardformer/policies/vit.py | 125 +++++++++--------- .../test_model/test_shard_vit_pipeline.py | 3 +- 3 files changed, 62 insertions(+), 68 deletions(-) diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 48c7c2a02c79..f28c13ad0aa2 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -117,7 +117,7 @@ def pp_forward( return_dict=return_dict, stage_manager=stage_manager, ) - return hidden_states + return {'hidden_states': hidden_states} else: encoder_outputs = _encoder_forward( encoder=self.encoder, diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index d0f2d47f84fa..47f2c58fc436 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -2,9 +2,8 @@ from typing import Callable, Dict, List, Union import torch.nn as nn -from torch import Tensor -from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row +import colossalai.shardformer.layer as col_nn from ..modeling.vit import ( ViTForImageClassification_pipeline_forward, @@ -31,65 +30,62 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForReplicatedInput, - ) - ]) - - policy[ViTLayer] = ModulePolicyDescription( - attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size//self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=DropoutForReplicatedInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - 
target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=DropoutForReplicatedInput, - ), - ] - ) + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ) + ]) + + policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ]) return policy - - + def new_model_class(self): return None @@ -168,16 +164,13 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: new_item = { ViTForImageClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription(suffix="classifier", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) + ]) } policy.update(new_item) - from transformers.models.vit.modeling_vit import ViTForImageClassification - if self.shard_config.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) self.set_pipeline_forward(model_cls=ViTForImageClassification, diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py index 5c3d77a62935..114992a2a2a5 100644 --- a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py @@ -51,7 +51,8 @@ def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy if name != 'transformers_vit': assert output.loss is not None else: - assert output.shape == hidden_state_shape, f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}' + assert output['hidden_states'].shape == hidden_state_shape, \ + f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}' 
torch.cuda.empty_cache()
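
For context on the pipeline-parallel pieces introduced in this series: the `get_held_layers()` and `set_pipeline_forward()` methods added to `ViTPolicy` split the ViT encoder blocks across pipeline stages using two helpers from the base `Policy` class, `distribute_layers` and `get_stage_index`. The first stage additionally holds `embeddings`, and the last stage holds the tail modules (`layernorm`/`pooler` for `ViTModel`, `classifier` for `ViTForImageClassification`, `decoder` for `ViTForMaskedImageModeling`). The sketch below is a hypothetical re-implementation of those two helpers for illustration only — the names come from the diffs above, but the even-split strategy and exact signatures are assumptions rather than the library's actual code.

```python
# Minimal sketch of the layer-partitioning logic used by get_held_layers() /
# set_pipeline_forward() in the ViT policies above. The even split with the
# remainder pushed to later stages is an assumption for illustration; consult
# colossalai.shardformer.policies.base_policy.Policy for the real behavior.
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Spread encoder layers as evenly as possible across pipeline stages,
    # handing any remainder to the later stages.
    quotient, remainder = divmod(num_layers, num_stages)
    return [quotient + (1 if stage >= num_stages - remainder else 0)
            for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Convert the per-stage layer counts into a [start, end) slice of
    # module.encoder.layer held by the given stage.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


# Example: the 4-layer test config from tests/kit/model_zoo/transformers/vit.py
# split over the 2 pipeline stages used in test_shard_vit_pipeline.py.
per_stage = distribute_layers(num_layers=4, num_stages=2)
print(per_stage)                      # [2, 2]
print(get_stage_index(per_stage, 0))  # (0, 2) -> stage 0 holds encoder.layer[0:2]
print(get_stage_index(per_stage, 1))  # (2, 4) -> stage 1 holds encoder.layer[2:4]
```

With this split, the `stage_index` tuple is what the pipeline forwards in `colossalai/shardformer/modeling/vit.py` receive as `start_idx`/`end_idx`, which is why a non-first stage expects a `hidden_states` tensor instead of `pixel_values` and why only the last stage runs the layernorm, pooler, classification, or reconstruction heads.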