15 changes: 9 additions & 6 deletions colossalai/context/parallel_context.py
@@ -415,7 +415,7 @@ def set_seed(self):
         if hasattr(self.config, 'seed'):
             seed = getattr(self.config, 'seed')
         else:
-            seed = 2 # default seed
+            seed = 1024 # default seed
 
         random.seed(seed)
         np.random.seed(seed)
@@ -426,15 +426,18 @@ def set_seed(self):
         if torch.cuda.is_available():
             # create random seed for different parallel modes
             # data parallel seed are kept the same
-            tp_rank = self._local_ranks.get(ParallelMode.TENSOR, 0)
-            pp_rank = self._local_ranks.get(ParallelMode.PIPELINE, 0)
-            parallel_seed = seed + tp_rank + pp_rank * 1024
+            parallel_seed = seed
             add_seed(ParallelMode.DATA, parallel_seed)
 
+            # model parallel seeds are different across ranks
+            pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
+
             # add seed for data parallel and tensor parallel only
             if self.is_initialized(ParallelMode.TENSOR):
-                dp_rank = self._local_ranks.get(ParallelMode.DATA, 0) + 1
-                tp_seed = parallel_seed + dp_rank * 128
+                tp_rank = self.get_local_rank(ParallelMode.TENSOR)
+                # 100 is only to increase the diff in seeds between pipeline stages
+                tp_rank_with_offset = tp_rank + pipeline_offset * 1024
+                tp_seed = seed + tp_rank_with_offset
                 add_seed(ParallelMode.TENSOR, tp_seed)
 
             set_mode(ParallelMode.DATA)
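For reference, a standalone sketch (not part of the diff) of what the revised scheme produces: every rank keeps the base seed for data parallelism, while tensor-parallel seeds are offset by the tensor-parallel rank plus 1024 per pipeline stage. The function name and toy world sizes below are illustrative assumptions.

# Illustrative only: mirrors the new seeding arithmetic above so the resulting
# seeds can be inspected; the name and toy world sizes are assumptions.
def illustrate_seeds(base_seed=1024, tp_world_size=4, pp_world_size=2):
    for pp_rank in range(pp_world_size):
        for tp_rank in range(tp_world_size):
            data_seed = base_seed  # identical on every rank
            tensor_seed = base_seed + tp_rank + pp_rank * 1024
            print(f"pp={pp_rank} tp={tp_rank} -> data={data_seed}, tensor={tensor_seed}")

illustrate_seeds()
# pp=0 stages get tensor seeds 1024-1027, pp=1 stages get 2048-2051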
1 change: 1 addition & 0 deletions colossalai/nn/layer/__init__.py
@@ -1,3 +1,4 @@
+from .fused_bias_gelu import bias_gelu_impl
 from .parallel_1d import *
 from .parallel_2d import *
 from .parallel_2p5d import *
35 changes: 35 additions & 0 deletions colossalai/nn/layer/fused_bias_gelu.py
@@ -0,0 +1,35 @@
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py

import torch

@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp

bias_gelu_impl = GeLUFunction.apply
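As a quick sanity check (not part of the PR), the fused path should agree with PyTorch's tanh-approximated GeLU applied to input + bias; the shapes, tolerance, and the approximate='tanh' keyword (available only in newer PyTorch releases) are assumptions for this sketch.

# Illustrative forward-only check against torch.nn.functional.gelu with the
# tanh approximation; shapes and tolerance are arbitrary choices.
import torch
import torch.nn.functional as F

x = torch.randn(8, 16, dtype=torch.float64)
b = torch.randn(16, dtype=torch.float64)

out = bias_gelu_impl(x, b)
ref = F.gelu(x + b, approximate='tanh')
assert torch.allclose(out, ref, atol=1e-6)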
4 changes: 2 additions & 2 deletions colossalai/nn/layer/parallel_1d/__init__.py
@@ -1,11 +1,11 @@
 from .layers import Linear1D_Col, Linear1D_Row
 from .layers import MixedFusedLayerNorm1D as LayerNorm1D
 from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
-from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHeadNormal, ViTSelfAttention1DV2
+from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHeadNormal
 
 
 
 __all__ = [
     'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
-    'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHeadNormal', 'ViTSelfAttention1DV2'
+    'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
 ]
34 changes: 34 additions & 0 deletions colossalai/nn/layer/parallel_1d/_operation.py
@@ -0,0 +1,34 @@
import torch

try:
    import fused_mix_prec_layer_norm_cuda
except:
    fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction1D(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output


    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)

        return grad_input, grad_weight, grad_bias, None, None
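A minimal usage sketch (not from the PR): the function is invoked through apply like any torch.autograd.Function, but it only runs when the fused_mix_prec_layer_norm_cuda extension imported above (built with NVIDIA Apex) is available and the tensors live on a CUDA device; the shapes and epsilon below are arbitrary assumptions.

# Illustrative only; assumes the fused_mix_prec_layer_norm_cuda extension is
# installed and a GPU is present, otherwise the import above leaves it as None.
import torch

hidden_size = 256
x = torch.randn(4, 32, hidden_size, device='cuda', requires_grad=True)
weight = torch.ones(hidden_size, device='cuda', requires_grad=True)
bias = torch.zeros(hidden_size, device='cuda', requires_grad=True)

out = FusedLayerNormAffineFunction1D.apply(x, weight, bias, (hidden_size,), 1e-5)
out.sum().backward()  # gradients flow through the custom backward above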