diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 9ea3d95de5b2..85e6d509c81b 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -126,3 +126,28 @@ def postprocess(self) -> nn.Module:
        the classifier layer
        """
        pass
+
+    def append_or_create_submodule_replacement(
+            self, description: Union[SubModuleReplacementDescription,
+                                      List[SubModuleReplacementDescription]], policy: Dict[Union[str, nn.Module],
+                                                                                            ModulePolicyDescription],
+            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        r"""
+        Append or create a new submodule replacement description to the policy for the given key.
+
+        Args:
+            description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended
+            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
+            target_key (Union[str, nn.Module]): the key of the policy to be updated
+        """
+        # convert to list
+        if isinstance(description, SubModuleReplacementDescription):
+            description = [description]
+
+        # append or create a new description
+        if target_key in policy:
+            policy[target_key].sub_module_replacement.extend(description)
+        else:
+            policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
+
+        return policy
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 5ab8fb825244..9c2736cc64d3 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -33,89 +33,114 @@ def preprocess(self):
     def module_policy(self):
         from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer

-        base_policy = {
-            BertLayer:
-                ModulePolicyDescription(
-                    attribute_replacement={
-                        # 1. shard hidden size
-                        "attention.self.all_head_size":
-                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                        "crossattention.self.all_head_size":
-                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                        # 2.
shard number of heads - "attention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "crossattention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.self.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]), - BertEmbeddings: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ) - ]) - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ + "attention.self.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attention.self.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + + policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + 
target_module=col_nn.DropoutForReplicatedInput, + ) + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[BertLayer].sub_module_replacement.append( + # Handle bert layer + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="attention.output.LayerNorm", target_module=col_nn.FusedLayerNorm, - )) - base_policy[BertLayer].sub_module_replacement.append( + ), SubModuleReplacementDescription( suffix="output.LayerNorm", target_module=col_nn.FusedLayerNorm, - )) - base_policy[BertEmbeddings].sub_module_replacement.append( - SubModuleReplacementDescription( + ) + ], + policy=policy, + target_key=BertLayer) + + # handle embedding layer + self.append_or_create_submodule_replacement( + description=[SubModuleReplacementDescription( suffix="LayerNorm", target_module=col_nn.FusedLayerNorm, - ),) + )], + policy=policy, + target_key=BertEmbeddings) + return policy + + def add_lm_head_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + + # optimize for tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + policy=base_policy, + target_key=BertLMPredictionHead) + + # optimize with fused normalization + if self.shard_config.enable_fused_normalization: + # Handle bert lm prediction head + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + policy=base_policy, + target_key=BertLMPredictionHead) return base_policy def postprocess(self): @@ -136,35 +161,14 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.bert.modeling_bert import BertLMPredictionHead - module_policy = super().module_policy() - addon_module = { - BertLMPredictionHead: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - ]) - } - - # optimization configuration - if self.shard_config.enable_fused_normalization: - addon_module[BertLMPredictionHead].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="transform.LayerNorm", - target_module=col_nn.FusedLayerNorm, - )) - - # append extra policy - module_policy.update(addon_module) + module_policy = self.add_lm_head_policy(module_policy) return module_policy def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -176,31 +180,14 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.bert.modeling_bert import BertLMPredictionHead - module_policy = super().module_policy() - addon_module = { - BertLMPredictionHead: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - ]) - } - if self.shard_config.enable_fused_normalization: - addon_module[BertLMPredictionHead].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="transform.LayerNorm", - 
target_module=col_nn.FusedLayerNorm, - )) - module_policy.update(addon_module) + module_policy = self.add_lm_head_policy(module_policy) return module_policy def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -212,34 +199,14 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.bert.modeling_bert import BertLMPredictionHead - module_policy = super().module_policy() - addon_module = { - BertLMPredictionHead: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - ]) - } - - # optimization configuration - if self.shard_config.enable_fused_normalization: - addon_module[BertLMPredictionHead].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="transform.LayerNorm", - target_module=col_nn.FusedLayerNorm, - )) - - module_policy.update(addon_module) + module_policy = self.add_lm_head_policy(module_policy) return module_policy def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -254,16 +221,18 @@ def module_policy(self): from transformers.models.bert.modeling_bert import BertForSequenceClassification module_policy = super().module_policy() - addon_module = { - BertForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForSequenceClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) return module_policy @@ -277,16 +246,18 @@ def module_policy(self): from transformers.models.bert.modeling_bert import BertForTokenClassification module_policy = super().module_policy() - addon_module = { - BertForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForTokenClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) return module_policy @@ -307,14 +278,16 @@ def module_policy(self): from transformers.models.bert.modeling_bert import BertForMultipleChoice module_policy = super().module_policy() - addon_module = { - BertForMultipleChoice: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + 
BertForMultipleChoice: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) return module_policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 00ab9159b0dc..030774a919d7 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -85,57 +85,53 @@ def preprocess(self): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel - base_policy = { - BloomBlock: - ModulePolicyDescription( - attribute_replacement={ - # 1. shard hidden size - "self_attention.hidden_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - # 2. shard number of heads - "self_attention.num_heads": - self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=col_nn.Linear1D_Row, - ), - ]), - BloomModel: - ModulePolicyDescription(attribute_replacement={ + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[BloomModel] = ModulePolicyDescription( + attribute_replacement={ "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, }, - method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) - ]) - } + method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[BloomModel].sub_module_replacement.extend([ + # 
handle bloom model + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="ln_f", target_module=col_nn.FusedLayerNorm, @@ -144,8 +140,12 @@ def module_policy(self): suffix="word_embeddings_layernorm", target_module=col_nn.FusedLayerNorm, ) - ]) - base_policy[BloomBlock].sub_module_replacement.extend([ + ], + policy=policy, + target_key=BloomModel) + + # handle bloom block + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="input_layernorm", target_module=col_nn.FusedLayerNorm, @@ -154,9 +154,11 @@ def module_policy(self): suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm, ) - ]) + ], + policy=policy, + target_key=BloomBlock) - return base_policy + return policy def postprocess(self): return self.model @@ -171,19 +173,19 @@ class BloomForCausalLMPolicy(BloomPolicy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForCausalLM policy = super().module_policy() - # add a new item for casual lm - new_item = { - BloomForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=BloomForCausalLM) + return policy def postprocess(self): binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} + for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -191,7 +193,6 @@ def postprocess(self): param = nn.Parameter(param) # tie weights - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -201,15 +202,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification policy = super().module_policy() - # add a new item for casual lm - new_item = { - BloomForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=BloomForSequenceClassification) + return policy @@ -218,19 +218,21 @@ class BloomForTokenClassificationPolicy(BloomPolicy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForTokenClassification policy = super().module_policy() - # add a new item for casual lm - new_item = { - BloomForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ]) - } - policy.update(new_item) + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + 
self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification) + return policy diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index ad0b1144a8a5..549cdbf87a80 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -31,67 +31,67 @@ def preprocess(self): def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model - base_policy = { - GPT2Model: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), - ]), - GPT2Block: - ModulePolicyDescription(attribute_replacement={ - "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) - } + policy = {} - # optimization configuration - if self.shard_config.enable_fused_normalization: - base_policy[GPT2Model].sub_module_replacement.append( + if self.shard_config.enable_tensor_parallelism: + policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - )) + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]) + policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + 
SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) - base_policy[GPT2Block].sub_module_replacement.extend([ + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + policy=policy, + target_key=GPT2Model) + + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="ln_1", target_module=col_nn.FusedLayerNorm, @@ -103,9 +103,10 @@ def module_policy(self): SubModuleReplacementDescription(suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True) - ]) - - return base_policy + ], + policy=policy, + target_key=GPT2Block) + return policy def postprocess(self): return self.model @@ -128,22 +129,22 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel module_policy = super().module_policy() - addon_module = { - GPT2LMHeadModel: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) return module_policy def postprocess(self): binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -158,22 +159,22 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel module_policy = super().module_policy() - addon_module = { - GPT2DoubleHeadsModel: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2DoubleHeadsModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) return module_policy def postprocess(self): binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 8f397693745c..157785bdcf13 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -28,58 +28,58 @@ def preprocess(self): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.llama.modeling_llama import 
LlamaDecoderLayer, LlamaModel - base_policy = { - LlamaDecoderLayer: - ModulePolicyDescription( - attribute_replacement={ - "self_attn.hidden_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=Linear1D_Row, - ) - ], - ), - LlamaModel: - ModulePolicyDescription(sub_module_replacement=[ + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, ) - ]) - } + ], + ) + + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[LlamaDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="input_layernorm", target_module=FusedRMSNorm, @@ -88,15 +88,18 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: suffix="post_attention_layernorm", target_module=FusedRMSNorm, ) - ]) + ], + policy=policy, + target_key=LlamaDecoderLayer) - base_policy[LlamaModel].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="norm", - target_module=FusedRMSNorm, - )) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=LlamaModel) - return base_policy + return policy def postprocess(self): return self.model @@ -108,15 +111,17 @@ def module_policy(self): from transformers import LlamaForCausalLM policy = super().module_policy() - # add a new item for casual lm 
- new_item = { - LlamaForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) return policy @@ -127,13 +132,14 @@ def module_policy(self): policy = super().module_policy() - # add a new item for sequence classification - new_item = { - LlamaForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 428ee2c9776c..b87db53f45f1 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -29,66 +29,67 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - base_policy = { - OPTDecoder: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), - OPTDecoderLayer: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), - OPTAttention: - ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + 
sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[OPTDecoder].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True)) - base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ]) + ], + policy=policy, + target_key=OPTDecoderLayer) - return base_policy + return policy def postprocess(self): return self.model @@ -106,15 +107,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - new_item = { - OPTForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 37fccaabc457..cde59ab77042 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -42,116 +42,126 @@ def module_policy(self): T5Stack, ) - base_policy = { - T5Stack: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=Embedding1D, - ) - ]), - T5LayerSelfAttention: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]), - T5LayerCrossAttention: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]), - T5Attention: - ModulePolicyDescription(attribute_replacement={ - "d_model": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "n_heads": - self.model.config.num_heads // self.shard_config.tensor_parallel_size, - "inner_dim": - self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - 
suffix="v", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="o", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription(suffix="relative_attention_bias", - target_module=Embedding1D, - kwargs=dict(gather_output=False), - ignore_if_not_exist=True) - ]), - T5LayerFF: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]), - T5DenseGatedActDense: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi_0", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wi_1", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]), - T5DenseActDense: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wo", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, + ) + ]) + policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]) + policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) + policy[T5Attention] = ModulePolicyDescription(attribute_replacement={ + "d_model": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": + self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": + self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + ignore_if_not_exist=True) + ]) + policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]) + policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) + 
policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[T5LayerFF].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) - base_policy[T5LayerSelfAttention].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) - base_policy[T5LayerCrossAttention].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) - base_policy[T5Stack].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm)) - - return base_policy + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerSelfAttention) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerCrossAttention) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5Stack) + return policy def postprocess(self): binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] @@ -166,14 +176,15 @@ class T5ModelPolicy(T5BasePolicy): def module_policy(self): from transformers import T5Model - base_policy = super().module_policy() - base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="shared", target_module=VocabParallelEmbedding1D, - ) - ]) + ), + policy=base_policy, + target_key=T5Model) return base_policy @@ -183,14 +194,19 @@ def module_policy(self): from transformers import T5ForConditionalGeneration policy = super().module_policy() - policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ], + policy=policy, + target_key=T5ForConditionalGeneration) return policy def postprocess(self): @@ -212,12 +228,14 
@@ def module_policy(self): from transformers import T5EncoderModel base_policy = super().module_policy() - base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="shared", target_module=VocabParallelEmbedding1D, - ) - ]) + ), + policy=base_policy, + target_key=T5EncoderModel) return base_policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0a5aa4cc4bdc..83c08d275df3 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,11 +13,12 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. - enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False @@ -33,8 +34,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index aa1424af3289..d83d9ecd39e0 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,12 +3,13 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn): +def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=True) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 87c4ef65bf1a..1afedb7079ea 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -3,7 +3,14 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + 
rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -33,36 +40,50 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # compare self attention grad org_grad = bert.encoder.layer[0].attention.self.query.weight.grad shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad + shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # compare embedding grad org_grad = bert.embeddings.word_embeddings.weight.grad shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad + shard_weight = sharded_bert.embeddings.word_embeddings.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -def check_bert(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 70d902a04517..a3389652269c 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,7 +3,14 @@ import colossalai from colossalai.logging import 
disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -32,10 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = bloom.h[0].self_attention.query_key_value.weight.grad shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad + shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" @@ -43,27 +54,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check embedding weights org_grad = bloom.word_embeddings.weight.grad shard_grad = sharded_bloom.word_embeddings.weight.grad + shard_weight = sharded_bloom.word_embeddings.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -def check_bloom(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff 
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index a4edc14bdbc3..ee7737687d99 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -3,7 +3,14 @@
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -32,11 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check mlp grad
     org_grad = org_model.h[0].mlp.c_fc.weight.grad
     shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
+    shard_weight = sharded_model.h[0].mlp.c_fc.weight

-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
-
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
@@ -44,27 +54,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check embedding weights
     org_grad = org_model.wte.weight.grad
     shard_grad = sharded_model.wte.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.wte.weight
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


-def check_gpt2(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()


+def check_gpt2(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_gpt2_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
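The gpt2 hunk above, like the other files, stacks two `@parameterize` decorators on the new runner and then calls `run_gpt2_test()` with no arguments from inside the spawned worker, so the decorator is expected to supply the argument combinations. A plain-Python illustration of the coverage this is assumed to give (four configurations), independent of the decorator's actual implementation:

```python
# Illustration only: the combinations a doubly-parameterized runner is expected to cover,
# assuming stacked @parameterize decorators expand to the cartesian product of their values.
from itertools import product

for enable_fused_normalization, enable_tensor_parallelism in product([True, False], [True, False]):
    # Each tuple corresponds to one invocation of run_gpt2_test under the stacked decorators.
    print(f"fused_norm={enable_fused_normalization}, tensor_parallel={enable_tensor_parallelism}")
```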
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index a98743a6143a..74b5fdd18af8 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -5,7 +5,14 @@
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -37,35 +44,48 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

     # check embedding grad
     org_grad = llama_model.embed_tokens.weight.grad
     shard_grad = shard_llama_model.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.embed_tokens.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"


-def check_llama(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
-
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()


+def check_llama(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_gpt2_llama()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
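One detail worth noting across these hunks: the gather in each check has to allocate one buffer per tensor-parallel rank and concatenate along the dimension that was actually sharded. The bloom, gpt2 and t5 tests gather two shards while the llama and opt tests gather four, and the gpt2 `mlp.c_fc` and t5 `relative_attention_bias` gradients are rebuilt with `dim=1` where most others use `dim=0`, because those weights store the sharded dimension second. A toy shape check with made-up sizes showing what the two concatenation directions reassemble:

```python
# Toy shape check (made-up sizes, no distributed setup): splitting a weight two ways
# and reassembling it with torch.cat along the matching dimension.
import torch

full = torch.randn(8, 4)                    # pretend this is the unsharded weight

shards_dim0 = list(full.chunk(2, dim=0))    # each rank would hold a (4, 4) slice
shards_dim1 = list(full.chunk(2, dim=1))    # each rank would hold an (8, 2) slice

assert torch.equal(torch.cat(shards_dim0, dim=0), full)
assert torch.equal(torch.cat(shards_dim1, dim=1), full)
```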
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 29cf2f6beed8..25bccb13b1a8 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -6,10 +6,11 @@
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
-    check_state_dict_equal,
     clear_cache_before_run,
+    parameterize,
     rerun_if_address_is_in_use,
     spawn,
 )
@@ -42,34 +43,48 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

     # check embedding grad
     org_grad = opt_model.decoder.embed_tokens.weight.grad
     shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.embed_tokens.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


-def check_OPTModel(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()


+def check_OPTModel(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_t5_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
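All five files now call `build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)`, so the shared `_utils.build_model` helper presumably threads the two flags into a shard configuration before sharding a copy of the model. A rough sketch of what such a helper might look like; the body below is an assumption for illustration, not the actual helper, and the exact ShardFormer call may differ:

```python
# Assumed sketch of a build_model-style helper; not the real tests/.../test_model/_utils.py.
import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
    # Keep an untouched copy of the model so gradients can be compared against it.
    org_model = model_fn().cuda()

    # The two test flags map onto the shard configuration.
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)

    # Shard a deep copy so the original stays intact (exact ShardFormer API assumed here).
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model = shard_former.optimize(copy.deepcopy(org_model)).cuda()
    return org_model, sharded_model
```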
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 91430bce918f..0762dc09e5af 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -5,7 +5,14 @@
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -27,19 +34,28 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check attention grad
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
-
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

     # check self attention embed
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
@@ -52,25 +68,34 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()

     shard_grad = sharded_model.shared.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = sharded_model.shared.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad

     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


-def check_t5(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()


+def check_t5(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_t5_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
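The final t5 hunk keeps asserting that `shared.weight` and `lm_head.weight` report the same `data_ptr()` after sharding, i.e. that weight tying survives the policy's submodule replacement. A self-contained illustration of what that check detects, using plain PyTorch modules rather than the sharded T5 model:

```python
# Standalone illustration of the data_ptr() tie check: two modules sharing one parameter
# report the same underlying storage address.
import torch.nn as nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight    # tie the weights, as T5 ties shared/lm_head

assert embed.weight.data.data_ptr() == lm_head.weight.data.data_ptr()

# An untied head would fail the same check.
untied_head = nn.Linear(4, 10, bias=False)
assert untied_head.weight.data.data_ptr() != embed.weight.data.data_ptr()
```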