3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -1,10 +1,11 @@
from .dropout import Dropout1D
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .layernorm import LayerNorm1D
from .linear import Linear1D_Col, Linear1D_Row
from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
from .loss import cross_entropy_1d

__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
"Dropout1D", "cross_entropy_1d"
"Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
]
89 changes: 89 additions & 0 deletions colossalai/shardformer/layer/layernorm.py
@@ -0,0 +1,89 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import List, Union

import torch
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.kernel import LayerNorm
from colossalai.nn import init as init

from .parallel_module import ParallelModule

__all__ = ['LayerNorm1D']

Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass


class LayerNorm1D(ParallelModule):
r"""
Layer normalization module for Colossal-AI

Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""

_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]

def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
else:
norm = None
try:
from apex.normalization import FusedLayerNorm
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
except ImportError:
norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
self.norm = norm

@staticmethod
def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a native PyTorch LayerNorm module to a Colossal-AI layer norm module
"""
normalized_shape = module.normalized_shape
eps = module.eps
bias = module.bias is not None
dtype = module.weight.dtype
device = module.weight.device

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

# create layer norm
layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm

with torch.no_grad():
# copy weight and bias
layer_norm.weight.copy_(module.weight)
if bias:
layer_norm.bias.copy_(module.bias)
return layer_norm
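
For context, here is a minimal usage sketch of the new module (not part of the diff; it assumes a CUDA device and mirrors the unit test further below):

import torch
import torch.nn as nn
from colossalai.shardformer.layer import LayerNorm1D

# wrap a native PyTorch LayerNorm; from_native_module returns the underlying
# norm (FastLayerNorm or FusedLayerNorm when apex is available, else LayerNorm)
native = nn.LayerNorm(1024, eps=1e-5).cuda()
norm1d = LayerNorm1D.from_native_module(native, process_group=None)

x = torch.rand(4, 1024, device='cuda')
# fused kernels may differ by float rounding, hence the tolerance
assert torch.allclose(native(x), norm1d(x), atol=1e-5)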
122 changes: 110 additions & 12 deletions colossalai/shardformer/policies/bert.py
@@ -1,8 +1,14 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertForMultipleChoice,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMPredictionHead,
)

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -24,7 +30,7 @@ def preprocess(self):
return self.model

def module_policy(self):
return {
base_policy = {
BertLayer:
ModulePolicyDescription(
attribute_replacement={
@@ -53,10 +59,18 @@ def module_policy(self):
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=col_nn.Dropout1D,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.Dropout1D,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
@@ -66,12 +80,8 @@
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
target_module=Dropout1D,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=Dropout1D,
suffix="output.dropout",
target_module=col_nn.Dropout1D,
)
]),
BertEmbeddings:
@@ -81,10 +91,32 @@
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}

if self.shard_config.fused_layernorm:
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
base_policy[BertEmbeddings].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=col_nn.LayerNorm1D,
),)
return base_policy

def new_model_class(self):
# do nothing
return self.model
@@ -115,9 +147,15 @@ def module_policy(self):
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
module_policy.update(addon_module)
return module_policy

@@ -146,9 +184,15 @@ def module_policy(self):
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
module_policy.update(addon_module)
return module_policy

@@ -177,9 +221,15 @@ def module_policy(self):
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
kwargs={"gather_output": True}),
])
}
if self.shard_config.fused_layernorm:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="transform.LayerNorm",
target_module=col_nn.LayerNorm1D,
))
module_policy.update(addon_module)
return module_policy

@@ -199,13 +249,45 @@ class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}
module_policy.update(addon_module)
return module_policy


# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}
module_policy.update(addon_module)
return module_policy


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
@@ -219,3 +301,19 @@ class BertForMultipleChoicePolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.Dropout1D,
)
])
}
module_policy.update(addon_module)
return module_policy
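
To make the policy mechanics concrete, a hedged sketch of how a sharder could consume these descriptions follows (a hypothetical driver loop, not the actual shardformer sharder; getattr_ and setattr_ are the helpers imported at the top of this file, and desc.kwargs is assumed to be optional):

# walk the replacement descriptions for BertLayer, resolve each dotted
# suffix to the native submodule, convert it, and swap it back in
for desc in policy[BertLayer].sub_module_replacement:
    native = getattr_(bert_layer, desc.suffix)
    kwargs = getattr(desc, 'kwargs', None) or {}
    replaced = desc.target_module.from_native_module(native, process_group, **kwargs)
    setattr_(bert_layer, desc.suffix, replaced)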
4 changes: 3 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
Original file line number Diff line number Diff line change
@@ -11,15 +11,17 @@ class ShardConfig:
The config for sharding the HuggingFace model

Args:
data_parallel_size (int): The size of data parallel
tensor_parallel_size (int): The size of tensor parallel
fused_layernorm (bool): Whether to replace layer norm with a fused implementation (`MixedFusedLayerNorm`), defaults to False
data_parallel_size (int): The size of data parallel
pipeline_parallel_size (int): The size of pipeline parallel
tensor_parallel_mode (List): The mode of tensor parallel, chosen from `['1d','2d','2.5d','3d']`
inference_only (bool): Whether to use inference-only mode; when set to `True`, the model
will not calculate the loss and will just return the output.
gather_output (bool): Whether to gather the output of the last layer of the model
"""
tensor_parallel_size: int
fused_layernorm: bool = False

# TODO: add support for tensor parallel
# pipeline_parallel_size: int
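
A short, hedged example of the new flag (assuming ShardConfig is constructed directly as a dataclass, as the field declarations suggest):

from colossalai.shardformer.shard.shard_config import ShardConfig

# opting in to fused layer norm makes the bert policies above append
# LayerNorm1D replacements for the attention, output and embedding norms
shard_config = ShardConfig(tensor_parallel_size=2, fused_layernorm=True)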
45 changes: 45 additions & 0 deletions tests/test_shardformer/test_layer/test_layernorm.py
@@ -0,0 +1,45 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import LayerNorm1D
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_layernorm_1d():
norm = nn.LayerNorm(128, 0.00001).cuda()
norm1d = LayerNorm1D.from_native_module(norm, process_group=None)

assert norm1d.weight.shape == torch.Size([128])

# ensure state dict is reversibly loadable
norm.load_state_dict(norm1d.state_dict())
norm1d.load_state_dict(norm.state_dict())

# check computation correctness
x = torch.rand(4, 128).cuda()
out = norm(x)
gather_out = norm1d(x)
assert_close(out, gather_out)

# check backward correctness
out.sum().backward()
gather_out.sum().backward()

assert_close(norm.weight.grad, norm1d.weight.grad)


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_layernorm_1d()


@rerun_if_address_is_in_use()
def test_layernorm_1d():
spawn(run_dist, nprocs=2)


if __name__ == '__main__':
test_layernorm_1d()
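
As the __main__ guard indicates, the test can also be launched directly with python tests/test_shardformer/test_layer/test_layernorm.py; it spawns two NCCL ranks, so two GPUs are assumed.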
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_layer/test_linearconv_1d.py
@@ -77,7 +77,7 @@ def check_linear_conv_1d_col():
assert_close(target_grad, linear_conv_col.weight.grad)


def check_linear_1d_row():
def check_linear_conv_1d_row():
linear = Conv1D(192, 48).cuda()
linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

@@ -103,7 +103,7 @@ def check_linear_1d_row():
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_conv_1d_col()
check_linear_1d_row()
check_linear_conv_1d_row()


@rerun_if_address_is_in_use()