4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/__init__.py
@@ -1,12 +1,12 @@
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import FusedLayerNorm
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
+from .normalization import FusedLayerNorm, FusedRMSNorm
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
-'FusedLayerNorm'
+'FusedLayerNorm', 'FusedRMSNorm'
]
colossalai/shardformer/layer/normalization.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

-__all__ = ['FusedLayerNorm']
+__all__ = ['FusedLayerNorm', 'FusedRMSNorm']

FAST_LAYERNORM_SUPPORTED_SIZE = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
@@ -61,4 +61,44 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
# copy weight and bias
layernorm.weight.copy_(module.weight)
layernorm.bias.copy_(module.bias)
return layernorm


class FusedRMSNorm():
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""

def __init__(self) -> None:
raise NotImplementedError(
'FusedRMSNorm is not implemented as a physical class. '
'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.'
)

@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
except ImportError:
raise ImportError(
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
)

# to check if it is huggingface LlamaRMSNorm
if module.__class__.__name__ == "LlamaRMSNorm":
normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon
elementwise_affine = True
else:
# get the attributes of the module
normalized_shape = module.normalized_shape
eps = module.eps
elementwise_affine = module.elementwise_affine

rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

with torch.no_grad():
# copy weight and bias
rmsnorm.weight.copy_(module.weight)

return rmsnorm
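For context, a minimal sketch of how the new wrapper is intended to be used outside the sharding pipeline. It assumes apex is installed, a CUDA device is available, and that transformers exposes LlamaRMSNorm with a (hidden_size, eps) constructor; the sizes are arbitrary.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from colossalai.shardformer.layer import FusedRMSNorm

# build a native RMSNorm and convert it via the new wrapper interface
native_norm = LlamaRMSNorm(4096, eps=1e-5).cuda()
fused_norm = FusedRMSNorm.from_native_module(native_norm).cuda()

x = torch.randn(2, 16, 4096, device="cuda")
# the two modules should agree up to numerical noise
print(torch.max(torch.abs(native_norm(x) - fused_norm(x))))
```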
8 changes: 8 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
@@ -98,6 +98,14 @@ def set_shard_config(self, shard_config: ShardConfig) -> None:
shard_config (:class:`ShardConfig`): The shard config to be perform
"""
self.shard_config = shard_config
self.config_sanity_check()

@abstractmethod
def config_sanity_check(self):
"""
Check if the shard config is valid for the model. Raise an exception if the config is invalid.
"""
pass

@abstractmethod
def preprocess(self) -> nn.Module:
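The base class now calls config_sanity_check() automatically inside set_shard_config, so a policy can fail fast on settings it cannot honour. A hedged sketch of what a non-trivial check could look like in a concrete policy; the subclass name and the specific check are illustrative, not part of this PR.

```python
import torch

from colossalai.shardformer.policies.bert import BertPolicy


class StrictBertPolicy(BertPolicy):
    """Hypothetical policy that rejects configurations it cannot honour."""

    def config_sanity_check(self):
        # runs automatically when set_shard_config() attaches a ShardConfig
        if self.shard_config.enable_fused_normalization and not torch.cuda.is_available():
            raise RuntimeError("enable_fused_normalization requires a CUDA device for the fused kernels")
```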
21 changes: 16 additions & 5 deletions colossalai/shardformer/policies/bert.py
@@ -16,6 +16,9 @@

class BertPolicy(Policy):

def config_sanity_check(self):
pass

def preprocess(self):
# reshape the embedding layer
r"""
@@ -99,7 +102,8 @@ def module_policy(self):
])
}

-if self.shard_config.fused_layernorm:
+# optimization configuration
+if self.shard_config.enable_fused_normalization:
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
@@ -150,12 +154,16 @@ def module_policy(self):
kwargs={"gather_output": True}),
])
}
-if self.shard_config.fused_layernorm:

+# optimization configuration
+if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))

# append extra policy
module_policy.update(addon_module)
return module_policy

@@ -187,7 +195,7 @@ def module_policy(self):
kwargs={"gather_output": True}),
])
}
-if self.shard_config.fused_layernorm:
+if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
@@ -224,12 +232,15 @@ def module_policy(self):
kwargs={"gather_output": True}),
])
}
-if self.shard_config.fused_layernorm:

+# optimization configuration
+if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm,
))

module_policy.update(addon_module)
return module_policy

@@ -316,4 +327,4 @@ def module_policy(self):
])
}
module_policy.update(addon_module)
return module_policy
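To make the new descriptions concrete, this is what the fused-normalization replacement boils down to for a single BertLayer, written out by hand. It is a sketch only: the sharder performs the substitution across all layers, apex must be installed, and bert_model stands for a HuggingFace BertForPreTraining instance (an assumption, not part of the diff).

```python
from colossalai.shardformer.layer import FusedLayerNorm

# bert_model: a transformers BertForPreTraining instance (assumed); the suffix
# mirrors the "attention.output.LayerNorm" description added in this policy
layer = bert_model.bert.encoder.layer[0]
layer.attention.output.LayerNorm = FusedLayerNorm.from_native_module(layer.attention.output.LayerNorm)
```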
31 changes: 29 additions & 2 deletions colossalai/shardformer/policies/bloom.py
@@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,

class BloomPolicy(Policy):

def config_sanity_check(self):
pass

def preprocess(self):
# reshape the embedding layer
r"""
@@ -81,7 +84,7 @@ def preprocess(self):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel

-return {
+base_policy = {
BloomBlock:
ModulePolicyDescription(
attribute_replacement={
@@ -99,7 +102,6 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
-# kwargs={'n_fused': 3}
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
@@ -132,6 +134,31 @@ def module_policy(self):
])
}

# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[BloomModel].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])
base_policy[BloomBlock].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])

return base_policy

def new_model_class(self):
# do nothing
return self.model
29 changes: 28 additions & 1 deletion colossalai/shardformer/policies/gpt2.py
@@ -9,6 +9,9 @@

class GPT2Policy(Policy):

def config_sanity_check(self):
pass

def preprocess(self):
# reshape the embedding layer
r"""
@@ -22,7 +25,7 @@ def preprocess(self):
return self.model

def module_policy(self):
-return {
+base_policy = {
GPT2Model:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -77,6 +80,30 @@ def module_policy(self):
])
}

# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[GPT2Model].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
))

base_policy[GPT2Block].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(suffix="ln_cross_attn",
target_module=col_nn.FusedLayerNorm,
ignore_if_not_exist=True)
])

return base_policy

def new_model_class(self):
return self.model

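The ln_cross_attn entry is guarded with ignore_if_not_exist=True because HuggingFace's GPT2Block only creates that LayerNorm when cross-attention is enabled. A small, self-contained check of that behaviour; the tiny config sizes are arbitrary and the snippet is illustrative, not part of the PR.

```python
from transformers import GPT2Config, GPT2Model

# a deliberately small config so the check is cheap
decoder_only = GPT2Model(GPT2Config(n_layer=1, n_embd=64, n_head=2))
print(hasattr(decoder_only.h[0], "ln_cross_attn"))      # False

with_cross_attn = GPT2Model(GPT2Config(n_layer=1, n_embd=64, n_head=2, add_cross_attention=True))
print(hasattr(with_cross_attn.h[0], "ln_cross_attn"))   # True
```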
28 changes: 26 additions & 2 deletions colossalai/shardformer/policies/llama.py
@@ -4,13 +4,16 @@
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class LlamaPolicy(Policy):

def config_sanity_check(self):
pass

def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
@@ -23,7 +26,7 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-return {
+base_policy = {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
@@ -75,6 +78,27 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
])
}

# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
)
])

base_policy[LlamaModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
))

return base_policy

def new_model_class(self):
return None

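Roughly what applying these descriptions amounts to for a LlamaModel, written out by hand. This is a simplified stand-in for the sharder's replacement pass, not its actual implementation; it assumes apex is installed and that llama_model is a transformers LlamaModel instance.

```python
from colossalai.shardformer.layer import FusedRMSNorm


def replace_submodule(root, suffix, target_module):
    """Swap the submodule addressed by `suffix` for target_module.from_native_module(...)."""
    parent_name, _, child_name = suffix.rpartition(".")
    parent = root.get_submodule(parent_name) if parent_name else root
    setattr(parent, child_name, target_module.from_native_module(getattr(parent, child_name)))


# llama_model: a transformers LlamaModel instance (assumed)
replace_submodule(llama_model, "norm", FusedRMSNorm)
for decoder_layer in llama_model.layers:
    replace_submodule(decoder_layer, "input_layernorm", FusedRMSNorm)
    replace_submodule(decoder_layer, "post_attention_layernorm", FusedRMSNorm)
```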
8 changes: 7 additions & 1 deletion colossalai/shardformer/policies/opt.py
@@ -13,6 +13,9 @@

class OPTPolicy(Policy):

def config_sanity_check(self):
pass

def preprocess(self):
# reshape the embedding layer
r"""
@@ -74,7 +77,9 @@ def module_policy(self):
),
]),
}
-if self.shard_config.fused_layernorm:

+# optimization configuration
+if self.shard_config.enable_fused_normalization:
base_policy[OPTDecoder].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm",
target_module=FusedLayerNorm,
@@ -87,6 +92,7 @@
target_module=FusedLayerNorm,
ignore_if_not_exist=True)
])

return base_policy

def new_model_class(self):