Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,28 @@ def postprocess(self) -> nn.Module:
the classifier layer
"""
pass

def append_or_create_submodule_replacement(
self, description: Union[SubModuleReplacementDescription,
List[SubModuleReplacementDescription]], policy: Dict[Union[str, nn.Module],
ModulePolicyDescription],
target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
Append or create a new submodule replacement description to the policy for the given key.

Args:
submodule_replace_desc (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended
policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
target_key (Union[str, nn.Module]): the key of the policy to be updated
"""
# convert to list
if isinstance(description, SubModuleReplacementDescription):
description = [description]

# append or create a new description
if target_key in policy:
policy[target_key].sub_module_replacement.extend(description)
else:
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)

return policy
299 changes: 136 additions & 163 deletions colossalai/shardformer/policies/bert.py

Large diffs are not rendered by default.

166 changes: 84 additions & 82 deletions colossalai/shardformer/policies/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,57 +85,53 @@ def preprocess(self):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel

base_policy = {
BloomBlock:
ModulePolicyDescription(
attribute_replacement={
# 1. shard hidden size
"self_attention.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
# 2. shard number of heads
"self_attention.num_heads":
self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
),
]),
BloomModel:
ModulePolicyDescription(attribute_replacement={
policy = {}

if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
),
])

policy[BloomModel] = ModulePolicyDescription(
attribute_replacement={
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
])
}
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
])

# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[BloomModel].sub_module_replacement.extend([
# handle bloom model
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
Expand All @@ -144,8 +140,12 @@ def module_policy(self):
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])
base_policy[BloomBlock].sub_module_replacement.extend([
],
policy=policy,
target_key=BloomModel)

# handle bloom block
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
Expand All @@ -154,9 +154,11 @@ def module_policy(self):
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])
],
policy=policy,
target_key=BloomBlock)

return base_policy
return policy

def postprocess(self):
return self.model
Expand All @@ -171,27 +173,26 @@ class BloomForCausalLMPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
policy = super().module_policy()
# add a new item for casual lm
new_item = {
BloomForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)

# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForCausalLM)

return policy

def postprocess(self):
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}

for k, v in binding_map.items():
param = getattr_(self.model, k)

if not isinstance(param, nn.Parameter):
param = nn.Parameter(param)

# tie weights
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model

Expand All @@ -201,15 +202,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
policy = super().module_policy()
# add a new item for casual lm
new_item = {
BloomForSequenceClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)

# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForSequenceClassification)

return policy


Expand All @@ -218,19 +218,21 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
policy = super().module_policy()
# add a new item for casual lm
new_item = {
BloomForTokenClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
}
policy.update(new_item)

# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
policy=policy,
target_key=BloomForTokenClassification)

return policy


Expand Down
Loading