25 changes: 25 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -10,16 +10,41 @@ def build_policies():
"""
auto_policy_dict = {}

from transformers import BertModel

from .bert import BertModelPolicy
auto_policy_dict[BertModel] = BertModelPolicy

from transformers import BertForPreTraining

from .bert import BertForPretrainingPolicy
auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy

from transformers import BertLMHeadModel

from .bert import BertLMHeadModelPolicy
auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy

from transformers import BertForMaskedLM

from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

from transformers import BertForNextSentencePrediction

from .bert import BertForNextSentencePredictionPolicy
auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy

from transformers import BertForSequenceClassification

from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

from transformers import BertForMultipleChoice

from .bert import BertForMultipleChoicePolicy
auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy

from transformers import GPT2Model

from .gpt2 import GPT2Policy
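For reference, a minimal sketch of how a registry built by `build_policies()` might be consumed; the `get_autopolicy` helper below is an illustrative assumption, not necessarily the function this file actually exposes:

```python
# Hypothetical consumer of the class-to-policy mapping assembled above.
import torch.nn as nn

from colossalai.shardformer.policies.autopolicy import build_policies


def get_autopolicy(model: nn.Module):
    """Return the policy class registered for this model's class, or None."""
    auto_policy_dict = build_policies()
    return auto_policy_dict.get(model.__class__, None)
```

Keying the dictionary on the `transformers` classes keeps the lookup a single `dict.get` on the model's type, which is why each new Bert variant in this diff only needs one import and one assignment.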
112 changes: 93 additions & 19 deletions colossalai/shardformer/policies/bert.py
@@ -35,12 +35,6 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
]),
}

@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}

@staticmethod
def attn_in():
return [
@@ -148,9 +142,53 @@ def embedding():
replace_layer=col_nn.VocabParallelEmbedding1D,
)]

@staticmethod
def unembedding():
return [
Col_Layer(
suffix="decoder",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)
]


# BertModel
class BertModelPolicy(BertPolicy):

@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)


from transformers import BertForMaskedLM
# BertForPretraining
class BertForPretrainingPolicy(BertPolicy):

@staticmethod
def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument

@staticmethod
def inject_policy():
return None

@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}


# BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_


@@ -161,7 +199,7 @@ def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertForMaskedLMPolicy.unembedding,
BertPolicy.unembedding,
]),
}
argument.update(base_argument)
Expand All @@ -173,20 +211,56 @@ def inject_policy():
return None

@staticmethod
def unembedding():
return [
Col_Layer(
suffix="decoder",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)
]
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}


class BertForSequenceClassificationPolicy(BertPolicy):
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):

@staticmethod
def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument

@staticmethod
def inject_policy():
return None

@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):

@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)


# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):

@staticmethod
def argument_policy(config, world_size):
return BertPolicy.argument_policy(config, world_size)
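The repeated `binding_policy` entries above map `bert.embeddings.word_embeddings.weight` to `cls.predictions.decoder.weight` so the input embedding and the LM head keep sharing one tensor after sharding. A rough sketch of how such a mapping could be applied (the `resolve_attr` and `apply_binding` helpers are illustrative assumptions, not the sharder's actual code):

```python
import torch.nn as nn


def resolve_attr(obj, dotted_path: str):
    """Walk a dotted attribute path such as 'bert.embeddings.word_embeddings.weight'."""
    for name in dotted_path.split("."):
        obj = getattr(obj, name)
    return obj


def apply_binding(model: nn.Module, binding: dict) -> None:
    """Point each destination parameter at its source so both share storage."""
    for src_path, dst_path in binding.items():
        src_param = resolve_attr(model, src_path)
        *parent_path, leaf = dst_path.split(".")
        parent = resolve_attr(model, ".".join(parent_path))
        # Re-assigning the parameter on its parent module re-ties the weights.
        setattr(parent, leaf, src_param)
```

Applied after sharding, `apply_binding(model, BertForMaskedLMPolicy.binding_policy())` would re-tie the sharded word embedding to the decoder of the prediction head.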
4 changes: 2 additions & 2 deletions colossalai/shardformer/shard/shard_config.py
@@ -13,6 +13,6 @@ class ShardConfig:
world_size (int): The world size of the distributed processes
gather_output (bool): Whether to gather the output of the last layer of the model
"""
rank: int
world_size: int = 2
rank: int = None
world_size: int = None
gather_output: bool = True
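With both fields now defaulting to `None`, a caller that builds the config by hand would typically fill them from the initialized process group; a minimal sketch, assuming `torch.distributed` has already been set up (e.g. by `colossalai.launch`):

```python
import torch.distributed as dist

from colossalai.shardformer.shard.shard_config import ShardConfig

shard_config = ShardConfig(
    rank=dist.get_rank(),              # rank of the current process
    world_size=dist.get_world_size(),  # total number of processes
    gather_output=True,                # gather the last layer's output on every rank
)
```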
1 change: 1 addition & 0 deletions colossalai/shardformer/shard/sharder.py
@@ -271,6 +271,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Poli
shard_config (`ShardConfig`): the config containing the distribution information
policy (`Policy`): the custom policy for sharding
"""
# TODO: init shard_config automatically
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
sharder.shard()
return model
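One way the TODO could be resolved is to fall back to the default process group when no config is supplied; a hedged sketch of such a helper, not the current implementation:

```python
import torch.distributed as dist

from colossalai.shardformer.shard.shard_config import ShardConfig


def _auto_shard_config(shard_config: ShardConfig = None) -> ShardConfig:
    """Build a ShardConfig from the default process group when none is given."""
    if shard_config is not None:
        return shard_config
    # Assumes torch.distributed has already been initialized.
    return ShardConfig(rank=dist.get_rank(), world_size=dist.get_world_size())
```

`shard_model` could then call `shard_config = _auto_shard_config(shard_config)` before constructing the `ModelSharder`.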
42 changes: 30 additions & 12 deletions tests/test_shardformer/test_model/test_shard_bert.py
@@ -1,9 +1,19 @@
import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
from transformers import (
AutoTokenizer,
BertConfig,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForSequenceClassification,
BertLMHeadModel,
BertModel,
)

import colossalai
from colossalai.logging import disable_existing_loggers
@@ -15,20 +25,21 @@
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def build_model(rank, world_size):
def build_model(rank, world_size, model):
config = BertConfig.from_pretrained('bert-base-uncased')
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0

org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
org_model = model(config=config)
org_model_forshard = copy.deepcopy(org_model)

org_model = org_model.to('cuda')
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
shardconfig).to('cuda')
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')

return org_model, sharded_model

@@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model):
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

org_model, sharded_model = build_model(rank, world_size)
check_forward(org_model, sharded_model)
check_backward(org_model, sharded_model)

torch.cuda.empty_cache()
forward_list = [
BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
BertForSequenceClassification
]
backward_list = [BertForMaskedLM, BertLMHeadModel]

for model in forward_list:
org_model, sharded_model = build_model(rank, world_size, model)
check_forward(org_model, sharded_model)
if model in backward_list:
check_backward(org_model, sharded_model)

torch.cuda.empty_cache()


@pytest.mark.dist
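The elided entrypoint under `@pytest.mark.dist` typically spawns `check_bert` across the test's world size; a rough sketch using plain `torch.multiprocessing` (the world size and port below are assumptions, not necessarily what the real test uses):

```python
import torch.multiprocessing as mp


def test_bert(world_size: int = 2, port: int = 29500):
    # mp.spawn passes the worker index as the first argument, which
    # check_bert receives as its rank.
    mp.spawn(check_bert, args=(world_size, port), nprocs=world_size)


if __name__ == '__main__':
    test_bert()
```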