14 changes: 12 additions & 2 deletions colossalai/shardformer/policies/basepolicy.py
@@ -76,6 +76,7 @@ class for the example.

def __init__(self) -> None:
self.model = None
self.shard_config = None

def set_model(self, model: nn.Module) -> None:
r"""
@@ -86,14 +87,23 @@ def set_model(self, model: nn.Module) -> None:
"""
self.model = model

def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.

Args:
        shard_config (:class:`ShardConfig`): The shard config to be used
"""
self.shard_config = shard_config

@abstractmethod
def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
def preprocess(self) -> nn.Module:
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
"""

@abstractmethod
def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
    Return the dict for the module replacement policy; the key is the original layer class and the value is the
    argument for the replacement layer
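For illustration, a minimal sketch (not part of this PR) of how a downstream policy is written against the new interface: the sharder injects the model and the shard config via `set_model` / `set_shard_config`, and `preprocess` / `module_policy` read `self.shard_config` instead of taking a `shard_config` argument. `MyModelPolicy` and its body are hypothetical, and other hooks of `Policy` are omitted for brevity.

from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription, Policy


class MyModelPolicy(Policy):

    def preprocess(self) -> nn.Module:
        # read the config from the attribute set by the sharder
        world_size = self.shard_config.tensor_parallel_size
        # ... e.g. pad embeddings so their sizes divide evenly by world_size ...
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        # no shard_config parameter anymore; use self.shard_config where needed
        return {}

    def postprocess(self) -> nn.Module:
        return self.model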
92 changes: 73 additions & 19 deletions colossalai/shardformer/policies/bert.py
@@ -4,41 +4,40 @@
import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D

from ..shard.shard_config import ShardConfig
from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class BertPolicy(Policy):

def preprocess(self, shard_config: ShardConfig = None):
def preprocess(self):
# reshape the embedding layer
r"""
        Reshape the token embedding layer to make the vocabulary size divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = shard_config.tensor_parallel_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model

def module_policy(self, shard_config: ShardConfig = None):
def module_policy(self):
return {
BertLayer:
ModulePolicyDescription(
attribute_replacement={
# 1. shard hidden size
"attention.self.all_head_size":
self.model.config.hidden_size // shard_config.tensor_parallel_size,
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"crossattention.self.all_head_size":
self.model.config.hidden_size // shard_config.tensor_parallel_size,
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
# 2. shard number of heads
"attention.self.num_attention_heads":
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
@@ -100,13 +99,43 @@ def postprocess(self):
return self.model


# BertModel
class BertModelPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()


# BertForPreTraining
class BertForPretrainingPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy


# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self, shard_config: ShardConfig = None):
module_policy = super().module_policy(shard_config)
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
@@ -124,16 +153,41 @@ def module_policy(self, shard_config: ShardConfig = None):
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):

@staticmethod
def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertPolicy.unembedding,
]),
def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
argument.update(base_argument)
return argument
module_policy.update(addon_module)
return module_policy


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()


# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()
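As a side note on `BertPolicy.preprocess` above: the padding arithmetic rounds the vocabulary size up to the next multiple of the tensor-parallel world size so the token embedding can be split evenly across ranks. A quick standalone check of that formula (plain Python, not part of the PR; 30522 is bert-base-uncased's vocabulary size, and world_size = 4 is an assumed example):

vocab_size, world_size = 30522, 4
if vocab_size % world_size != 0:
    new_vocab_size = vocab_size + world_size - vocab_size % world_size
else:
    new_vocab_size = vocab_size
# 30522 % 4 == 2, so the vocabulary is padded to 30524, which divides evenly
assert new_vocab_size == 30524 and new_vocab_size % world_size == 0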
6 changes: 3 additions & 3 deletions colossalai/shardformer/shard/shard_config.py
@@ -18,10 +18,10 @@ class ShardConfig:
will not calculate the loss and just return the output.
        gather_output (bool): Whether to gather the output of the last layer of the model
"""
data_parallel_size: int
tensor_parallel_size: int

pipeline_parallel_size: int
    # TODO: add support for data parallel and pipeline parallel
# pipeline_parallel_size: int
# data_parallel_size: int
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
inference_only: bool = True
gather_output: bool = True
9 changes: 5 additions & 4 deletions colossalai/shardformer/shard/sharder.py
@@ -40,6 +40,7 @@ def shard(self) -> None:
Shard the model according to the policy
"""
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self.preprocess()
self.replace_model_class()
self.replace_module()
@@ -57,12 +58,12 @@ def reshape_embedding(self) -> None:
self.model_config = self.model.config

def preprocess(self) -> None:
self.model = self.policy.preprocess(self.shard_config)
self.model = self.policy.preprocess()

def postprocess(self) -> None:
self.model = self.policy.postprocess()

def replace_model_class(self,) -> None:
def replace_model_class(self) -> None:
r"""
    Replace the model with the policy-defined model
    Mainly modifies the forward and backward passes to fit the distributed model
@@ -83,14 +84,14 @@ def replace_model_class(self,) -> None:
getattr(new_model_class, key),
)

def replace_module(self,) -> None:
def replace_module(self) -> None:
r"""
Replace the module according to the policy, and replace the module one by one

Args:
model (:class:`torch.nn.Module`): The model to shard
"""
module_descriptions = self.policy.module_policy(self.shard_config)
module_descriptions = self.policy.module_policy()
for module_description in module_descriptions.items():
origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement
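To summarize the reordered flow in `ModelSharder.shard` above, here is a hedged sketch of how the sharder now drives a policy: it attaches the model and the shard config first, then calls the argument-free hooks and walks the description dict returned by `module_policy()`. The `apply_policy` helper is hypothetical and only mirrors the order of operations shown in this diff, not the full replacement logic.

def apply_policy(model, policy, shard_config):
    """Hypothetical driver mirroring the order of operations in ModelSharder.shard()."""
    policy.set_model(model)
    policy.set_shard_config(shard_config)    # new step introduced by this PR
    model = policy.preprocess()              # no longer takes shard_config
    for origin_cls, description in policy.module_policy().items():
        attrs = description.attribute_replacement
        params = description.param_replacement
        submodules = description.sub_module_replacement
        # ... the real sharder rewrites matching modules of origin_cls in place ...
    return policy.postprocess()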
4 changes: 0 additions & 4 deletions colossalai/shardformer/shard/shardformer.py
@@ -25,11 +25,7 @@ class ShardFormer:
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig(
tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d',
inference_only=True,
gather_output=True
)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
11 changes: 4 additions & 7 deletions tests/test_shardformer/test_model/test_shard_bert.py
@@ -7,7 +7,6 @@
AutoTokenizer,
BertConfig,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForSequenceClassification,
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d',
inference_only=True,
gather_output=True)
shard_config = ShardConfig(
tensor_parallel_size=2,
tensor_parallel_mode='1d',
)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')