diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py
index 934b99b83ea1..46aa3b52af8f 100644
--- a/colossalai/shardformer/policies/chatglm.py
+++ b/colossalai/shardformer/policies/chatglm.py
@@ -90,7 +90,31 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                                                                 policy=policy,
                                                                 target_key=ChatGLMModel)
+            else:
+                self.append_or_create_submodule_replacement(description=[
+                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
+                    SubModuleReplacementDescription(suffix="post_attention_layernorm",
+                                                    target_module=col_nn.FusedRMSNorm)
+                ],
+                                                            policy=policy,
+                                                            target_key=GLMBlock)
+
+                if self.model.config.post_layer_norm:
+                    self.append_or_create_submodule_replacement(description=[
+                        SubModuleReplacementDescription(suffix="encoder.final_layernorm",
+                                                        target_module=col_nn.FusedRMSNorm)
+                    ],
+                                                                policy=policy,
+                                                                target_key=ChatGLMModel)
+
         return policy
 
     def postprocess(self):
         return self.model
+
+
+class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
+
+    def module_policy(self):
+        policy = super().module_policy()
+        return policy
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 7b035afae22c..d45055bc8beb 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -2,7 +2,13 @@
 
 import torch.nn as nn
 
-from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import (
+    DropoutForParallelInput,
+    DropoutForReplicatedInput,
+    FusedLayerNorm,
+    Linear1D_Col,
+    Linear1D_Row,
+)
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -18,101 +24,112 @@ def preprocess(self):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
+        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
 
         policy = {}
 
         if self.shard_config.enable_tensor_parallelism:
             policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
-                                                            param_replacement=[],
-                                                            sub_module_replacement=[
-                                                                SubModuleReplacementDescription(
-                                                                    suffix="dropout",
-                                                                    target_module=DropoutForReplicatedInput,
-                                                                )
-                                                            ])
-
-            policy[ViTLayer] = ModulePolicyDescription(
-                attribute_replacement={
-                    "attention.attention.num_attention_heads":
-                        self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
-                    "attention.attention.all_head_size":
-                        self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
-                },
-                param_replacement=[],
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="attention.attention.query",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attention.attention.key",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attention.attention.value",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attention.attention.dropout",
-                        target_module=DropoutForParallelInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attention.output.dense",
-                        target_module=Linear1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attention.output.dropout",
-                        target_module=DropoutForReplicatedInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="intermediate.dense",
-                        target_module=Linear1D_Col,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="output.dense",
-                        target_module=Linear1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="output.dropout",
-                        target_module=DropoutForReplicatedInput,
-                    ),
-                ]
-            )
+                                                            param_replacement=[],
+                                                            sub_module_replacement=[
+                                                                SubModuleReplacementDescription(
+                                                                    suffix="dropout",
+                                                                    target_module=DropoutForReplicatedInput,
+                                                                )
+                                                            ])
+
+            policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
+                "attention.attention.num_attention_heads":
+                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                "attention.attention.all_head_size":
+                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            },
+                                                       param_replacement=[],
+                                                       sub_module_replacement=[
+                                                           SubModuleReplacementDescription(
+                                                               suffix="attention.attention.query",
+                                                               target_module=Linear1D_Col,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="attention.attention.key",
+                                                               target_module=Linear1D_Col,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="attention.attention.value",
+                                                               target_module=Linear1D_Col,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="attention.attention.dropout",
+                                                               target_module=DropoutForParallelInput,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="attention.output.dense",
+                                                               target_module=Linear1D_Row,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="attention.output.dropout",
+                                                               target_module=DropoutForReplicatedInput,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="intermediate.dense",
+                                                               target_module=Linear1D_Col,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="output.dense",
+                                                               target_module=Linear1D_Row,
+                                                           ),
+                                                           SubModuleReplacementDescription(
+                                                               suffix="output.dropout",
+                                                               target_module=DropoutForReplicatedInput,
+                                                           ),
+                                                       ])
+
+        if self.shard_config.enable_fused_normalization:
+            policy[ViTModel] = ModulePolicyDescription(attribute_replacement={},
+                                                       param_replacement=[],
+                                                       sub_module_replacement=[
+                                                           SubModuleReplacementDescription(
+                                                               suffix="layernorm",
+                                                               target_module=FusedLayerNorm,
+                                                           )
+                                                       ])
+
+            self.append_or_create_submodule_replacement(description=[
+                SubModuleReplacementDescription(suffix="layernorm_before", target_module=FusedLayerNorm),
+                SubModuleReplacementDescription(suffix="layernorm_after", target_module=FusedLayerNorm)
+            ],
+                                                        policy=policy,
+                                                        target_key=ViTLayer)
 
         return policy
-
-
+
     def new_model_class(self):
         return None
 
     def postprocess(self):
         return self.model
 
+
 class ViTForImageClassificationPolicy(ViTPolicy):
 
-    def module_policy(self):
+    def module_policy(self):
         from transformers.models.vit.modeling_vit import ViTForImageClassification
 
         policy = super().module_policy()
 
         if self.shard_config.enable_tensor_parallelism:
             new_item = {
                 ViTForImageClassification:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(suffix="classifier",
-                                                    target_module=Linear1D_Col,
-                                                    kwargs=dict(gather_output=True))
-                ])
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+                    ])
             }
             policy.update(new_item)
 
         return policy
 
+
 class ViTForMaskedImageModelingPolicy(ViTPolicy):
-
+
     def module_policy(self):
         policy = super().module_policy()
         return policy
-
-
-
-
diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py
index 1408babede64..04e73a832abe 100644
--- a/tests/kit/model_zoo/transformers/chatglm.py
+++ b/tests/kit/model_zoo/transformers/chatglm.py
@@ -3,7 +3,7 @@
 
 from ..registry import ModelAttribute, model_zoo
 from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from .chatglm2_6b.modeling_chatglm import ChatGLMModel
+from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 # ================================
 # Register single-sentence ChatGLM
@@ -21,7 +21,7 @@ def data_gen():
 
 # define loss function
 loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean()
-loss_fn = lambda x: x.loss
+loss_fn = lambda x: x.logits.mean()
 
 config = ChatGLMConfig(num_layers=1,
                        padded_vocab_size=65024,
                        hidden_size=64,
@@ -36,3 +36,10 @@ def data_gen():
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_for_chatglm_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name="transformers_chatglm_for_conditional_generation",
+                   model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
+                   data_gen_fn=data_gen,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index 2cdf5da2e6da..a0fa4bd82e74 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -7,7 +7,7 @@
 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy
+from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
@@ -85,6 +85,8 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
         shard_former = ShardFormer(shard_config=shard_config)
         if name == "transformers_chatglm":
             sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
+        else:
+            sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda()
 
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
         torch.cuda.empty_cache()
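
Usage note (illustrative, not part of the diff): a minimal sketch of how the new ChatGLMForConditionalGenerationPolicy is expected to be applied, following the pattern in tests/test_shardformer/test_model/test_shard_chatglm.py above. The `model` variable is a hypothetical, already-built ChatGLMForConditionalGeneration instance, and the sketch assumes ShardConfig accepts the enable_tensor_parallelism / enable_fused_normalization flags exercised by that test.

    # Sketch: shard a ChatGLM conditional-generation model with the new policy.
    from colossalai.shardformer import ShardConfig, ShardFormer
    from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy

    shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
    shard_former = ShardFormer(shard_config=shard_config)

    # `model` is assumed to be a ChatGLMForConditionalGeneration built elsewhere,
    # e.g. the "transformers_chatglm_for_conditional_generation" model-zoo entry.
    sharded_model = shard_former.optimize(model, ChatGLMForConditionalGenerationPolicy()).cuda()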