diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 54cc63ba124f..af828b46ce25 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -10,16 +10,41 @@ def build_policies(): """ auto_policy_dict = {} + from transformers import BertModel + + from .bert import BertModelPolicy + auto_policy_dict[BertModel] = BertModelPolicy + + from transformers import BertForPreTraining + + from .bert import BertForPretrainingPolicy + auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy + + from transformers import BertLMHeadModel + + from .bert import BertLMHeadModelPolicy + auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy + from transformers import BertForMaskedLM from .bert import BertForMaskedLMPolicy auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy + from transformers import BertForNextSentencePrediction + + from .bert import BertForNextSentencePredictionPolicy + auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy + from transformers import BertForSequenceClassification from .bert import BertForSequenceClassificationPolicy auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy + from transformers import BertForMultipleChoice + + from .bert import BertForMultipleChoicePolicy + auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy + from transformers import GPT2Model from .gpt2 import GPT2Policy diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 67e910d521e9..ba2266353e3e 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -35,12 +35,6 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: ]), } - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - 
@staticmethod def attn_in(): return [ @@ -148,9 +142,53 @@ def embedding(): replace_layer=col_nn.VocabParallelEmbedding1D, )] + @staticmethod + def unembedding(): + return [ + Col_Layer( + suffix="decoder", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ) + ] + + +# BertModel +class BertModelPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) + -from transformers import BertForMaskedLM +# BertForPretraining +class BertForPretrainingPolicy(BertPolicy): + @staticmethod + def argument_policy(config, world_size): + base_argument = BertPolicy.argument_policy(config, world_size) + argument = { + BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ + BertPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument + + @staticmethod + def inject_policy(): + return None + + @staticmethod + def binding_policy(): + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } + + +# BertForMaskedLM from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ @@ -161,7 +199,7 @@ def argument_policy(config, world_size): base_argument = BertPolicy.argument_policy(config, world_size) argument = { BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ - BertForMaskedLMPolicy.unembedding, + BertPolicy.unembedding, ]), } argument.update(base_argument) @@ -173,20 +211,56 @@ def inject_policy(): return None @staticmethod - def unembedding(): - return [ - Col_Layer( - suffix="decoder", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ) - ] + def binding_policy(): + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } -class BertForSequenceClassificationPolicy(BertPolicy): +# BertLMHeadModel +class BertLMHeadModelPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + 
base_argument = BertPolicy.argument_policy(config, world_size) + argument = { + BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ + BertPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument @staticmethod def inject_policy(): return None + + @staticmethod + def binding_policy(): + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } + + +# BertForNextSentencePrediction +class BertForNextSentencePredictionPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) + + +# BertForSequenceClassification +class BertForSequenceClassificationPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) + + +# BertForMultipleChoice +class BertForMultipleChoicePolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index e8d6f3408c76..96c287577ddc 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,6 +13,6 @@ class ShardConfig: world_size (int): The world size of the distributed process gather_output (bool): Whether to gather the output of the model of the last layer """ - rank: int - world_size: int = 2 + rank: int = None + world_size: int = None gather_output: bool = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 95184cfe6929..36b89c1c4acb 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -271,6 +271,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Poli shard_config (`ShardConfig`): the config for distribute information policy (`Policy`): the custom policy for 
sharding """ + # TODO: init shard_config automatically sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy) sharder.shard() return model diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 55b78d040505..9b29111eadb2 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -1,9 +1,19 @@ +import copy import os -import random import pytest import torch -from transformers import AutoTokenizer, BertConfig, BertForMaskedLM +from transformers import ( + AutoTokenizer, + BertConfig, + BertForMaskedLM, + BertForMultipleChoice, + BertForNextSentencePrediction, + BertForPreTraining, + BertForSequenceClassification, + BertLMHeadModel, + BertModel, +) import colossalai from colossalai.logging import disable_existing_loggers @@ -15,20 +25,21 @@ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") -def build_model(rank, world_size): +def build_model(rank, world_size, model): config = BertConfig.from_pretrained('bert-base-uncased') config.hidden_dropout_prob = 0 config.attention_probs_dropout_prob = 0 - org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda') + org_model = model(config=config) + org_model_forshard = copy.deepcopy(org_model) + org_model = org_model.to('cuda') shardconfig = ShardConfig( rank=rank, world_size=world_size, gather_output=True, ) - sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config), - shardconfig).to('cuda') + sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') return org_model, sharded_model @@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model): def check_bert(rank, world_size, port): disable_existing_loggers() colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - org_model, sharded_model = 
build_model(rank, world_size) - check_forward(org_model, sharded_model) - check_backward(org_model, sharded_model) - - torch.cuda.empty_cache() + forward_list = [ + BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction, + BertForSequenceClassification + ] + backward_list = [BertForMaskedLM, BertLMHeadModel] + + for model in forward_list: + org_model, sharded_model = build_model(rank, world_size, model) + check_forward(org_model, sharded_model) + if model in backward_list: + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() @pytest.mark.dist