4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/__init__.py
@@ -1,11 +1,11 @@
 from .dropout import Dropout1D
 from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import LayerNorm1D
+from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
 from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d
 
 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
+    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
 ]
101 changes: 38 additions & 63 deletions colossalai/shardformer/layer/layernorm.py
@@ -1,89 +1,64 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import List, Union
-
 import torch
 import torch.nn as nn
-from torch.distributed import ProcessGroup
 
-from colossalai.kernel import LayerNorm
-from colossalai.nn import init as init
-
-from .parallel_module import ParallelModule
+__all__ = ['FusedLayerNorm']
 
-__all__ = ['LayerNorm1D']
+FAST_LAYERNORM_SUPPORTED_SIZE = [
+    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
+    25600, 30720, 32768, 40960, 49152, 65536
+]
 
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
-
 
-class LayerNorm1D(ParallelModule):
+class FusedLayerNorm():
     r"""
-    Layer Normalization for colossalai
-
-    Args:
-        normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-                \times \ldots \times \text{normalized_shape}[-1]]`
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
-        bias (bool, optional): Whether to add a bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
 
-    _fast_ln_supported_sizes = [
-        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
-        24576, 25600, 30720, 32768, 40960, 49152, 65536
-    ]
-
-    def __init__(self,
-                 normalized_shape: int,
-                 eps: int = 1e-05,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 device: torch.device = None):
-        super().__init__()
-        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
-            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
-        else:
-            norm = None
-            try:
-                from apex.normalization import FusedLayerNorm
-                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
-            except ImportError:
-                norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
-        self.norm = norm
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            'FusedLayerNorm is not implemented as a physical class. '
+            'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.'
+        )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         r"""
         Convert a native pytorch layer norm module to colossalai layer norm module
         """
+        # check if apex is installed
+        try:
+            import apex
+        except ImportError:
+            raise ImportError(
+                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
+
+        # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
-        bias = module.bias is not None
+        elementwise_affine = module.elementwise_affine
         dtype = module.weight.dtype
         device = module.weight.device
 
-        # ensure only one process group is passed
-        if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
-            process_group = process_group[0]
+        # pick the suitable layernorm implementation
+        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
+
+        if use_fast_ln:
+            try:
+                from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm
+            except ImportError:
+                # fall back to the normal fused layernorm if the fast one is not built
+                from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+        else:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
 
-        # create layer norm
-        layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm
+        layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
+                                       elementwise_affine=elementwise_affine).to(dtype).to(device)
 
         with torch.no_grad():
             # copy weight and bias
-            layer_norm.weight.copy_(module.weight)
-            if bias:
-                layer_norm.bias.copy_(module.bias)
-        return layer_norm
+            layernorm.weight.copy_(module.weight)
+            layernorm.bias.copy_(module.bias)
+        return layernorm
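
For context, a minimal usage sketch of the new interface (not part of this diff). It assumes a CUDA device and an apex installation, since from_native_module raises ImportError without apex; the tolerances are illustrative.

import torch
import torch.nn as nn

from colossalai.shardformer.layer import FusedLayerNorm

native_norm = nn.LayerNorm(1024, eps=1e-5).cuda()
fused_norm = FusedLayerNorm.from_native_module(native_norm)

# weight and bias are copied over, so both modules should produce matching outputs
x = torch.rand(4, 1024, device='cuda')
torch.testing.assert_close(fused_norm(x), native_norm(x), rtol=1e-3, atol=1e-3)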
12 changes: 6 additions & 6 deletions colossalai/shardformer/policies/bert.py
@@ -103,17 +103,17 @@ def module_policy(self):
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="attention.output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertEmbeddings].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ),)
         return base_policy

@@ -154,7 +154,7 @@ def module_policy(self):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
@@ -191,7 +191,7 @@ def module_policy(self):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
@@ -228,7 +228,7 @@ def module_policy(self):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy
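
The policy entries above only swap the target_module from col_nn.LayerNorm1D to col_nn.FusedLayerNorm; each SubModuleReplacementDescription names a dotted suffix inside the wrapped HuggingFace module and the ColossalAI layer that should replace it. A conceptual sketch of what such a replacement amounts to is given below; the helper replace_submodule is hypothetical and is not Shardformer's actual sharding code.

import torch.nn as nn

from colossalai.shardformer.layer import FusedLayerNorm


def replace_submodule(model: nn.Module, suffix: str, target_module) -> None:
    # resolve the parent module of the dotted suffix, then swap the named child
    *parent_path, attr_name = suffix.split('.')
    parent = model.get_submodule('.'.join(parent_path)) if parent_path else model
    native = getattr(parent, attr_name)
    setattr(parent, attr_name, target_module.from_native_module(native))


# e.g. for a transformers BertLayer instance `layer`:
# replace_submodule(layer, "attention.output.LayerNorm", FusedLayerNorm)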
11 changes: 5 additions & 6 deletions tests/test_shardformer/test_layer/test_layernorm.py
@@ -1,16 +1,15 @@
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer import LayerNorm1D
+from colossalai.shardformer.layer import FusedLayerNorm
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
-def check_layernorm_1d():
+def check_layernorm():
     norm = nn.LayerNorm(128, 0.00001).cuda()
-    norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
+    norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
 
     assert norm1d.weight.shape == torch.Size([128])
 
@@ -33,11 +32,11 @@ def check_layernorm_1d():

 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_layernorm_1d()
+    check_layernorm()
 
 
 @rerun_if_address_is_in_use()
-def test_layernorm_1d():
+def test_layernorm():
     spawn(run_dist, nprocs=2)
 
 