diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index 4ab01c3eaff0..1a4d8bd432bc 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -415,7 +415,7 @@ def set_seed(self):
         if hasattr(self.config, 'seed'):
             seed = getattr(self.config, 'seed')
         else:
-            seed = 2  # default seed
+            seed = 1024  # default seed
 
         random.seed(seed)
         np.random.seed(seed)
@@ -426,15 +426,18 @@ def set_seed(self):
         if torch.cuda.is_available():
             # create random seed for different parallel modes
             # data parallel seed are kept the same
-            tp_rank = self._local_ranks.get(ParallelMode.TENSOR, 0)
-            pp_rank = self._local_ranks.get(ParallelMode.PIPELINE, 0)
-            parallel_seed = seed + tp_rank + pp_rank * 1024
+            parallel_seed = seed
             add_seed(ParallelMode.DATA, parallel_seed)
 
+            # model parallel seeds are different across ranks
+            pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
+
+            # add seed for data parallel and tensor parallel only
             if self.is_initialized(ParallelMode.TENSOR):
-                dp_rank = self._local_ranks.get(ParallelMode.DATA, 0) + 1
-                tp_seed = parallel_seed + dp_rank * 128
+                tp_rank = self.get_local_rank(ParallelMode.TENSOR)
+                # 1024 is only to increase the diff in seeds between pipeline stages
+                tp_rank_with_offset = tp_rank + pipeline_offset * 1024
+                tp_seed = seed + tp_rank_with_offset
                 add_seed(ParallelMode.TENSOR, tp_seed)
 
             set_mode(ParallelMode.DATA)
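For readability, here is a standalone sketch (not part of the patch) of the seeds this scheme produces; the two pipeline stages and four tensor-parallel ranks are made-up values:

# hypothetical topology: 2 pipeline stages x 4 tensor-parallel ranks per stage
seed = 1024
for pipeline_offset in range(2):
    for tp_rank in range(4):
        # every rank shares the data-parallel seed; tensor-parallel seeds differ,
        # offset by 1024 per pipeline stage so stages do not collide while tp_rank < 1024
        tp_seed = seed + tp_rank + pipeline_offset * 1024
        print(f'pp={pipeline_offset} tp={tp_rank}: data_seed={seed} tensor_seed={tp_seed}')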
diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index 1456a8a56a3f..9d4bb70efa42 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -1,3 +1,4 @@
+from .fused_bias_gelu import bias_gelu_impl
 from .parallel_1d import *
 from .parallel_2d import *
 from .parallel_2p5d import *
diff --git a/colossalai/nn/layer/fused_bias_gelu.py b/colossalai/nn/layer/fused_bias_gelu.py
new file mode 100644
index 000000000000..e920415349a0
--- /dev/null
+++ b/colossalai/nn/layer/fused_bias_gelu.py
@@ -0,0 +1,35 @@
+# adapted from Megatron-LM
+# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
+
+import torch
+
+
+@torch.jit.script
+def bias_gelu(bias, y):
+    x = bias + y
+    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+
+# gradient of tanh approximation of gelu
+# gradient of actual gelu is:
+# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+@torch.jit.script
+def bias_gelu_back(g, bias, y):
+    x = bias + y
+    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+    return ff * g
+
+
+class GeLUFunction(torch.autograd.Function):
+    @staticmethod
+    # bias is an optional argument
+    def forward(ctx, input, bias):
+        ctx.save_for_backward(input, bias)
+        return bias_gelu(bias, input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, bias = ctx.saved_tensors
+        tmp = bias_gelu_back(grad_output, bias, input)
+        return tmp, tmp
+
+
+bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
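A quick sanity check of the new kernel (a sketch, not part of the patch; shapes are arbitrary): the fused op takes the pre-bias linear output and matches an explicit bias add followed by the same tanh approximation of GeLU.

import torch
from colossalai.nn.layer.fused_bias_gelu import bias_gelu_impl

x = torch.randn(8, 16)    # linear output *before* the bias add
bias = torch.randn(16)

out = bias_gelu_impl(x, bias)   # the bias add happens inside the jitted kernel

# reference: unfused bias add + tanh-approximated GeLU, same constants as above
y = x + bias
ref = y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1 + 0.044715 * y * y)))
assert torch.allclose(out, ref, atol=1e-6)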
diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py
index 72e6e5a7c82c..cf262053a60b 100644
--- a/colossalai/nn/layer/parallel_1d/__init__.py
+++ b/colossalai/nn/layer/parallel_1d/__init__.py
@@ -1,11 +1,11 @@
 from .layers import Linear1D_Col, Linear1D_Row
 from .layers import MixedFusedLayerNorm1D as LayerNorm1D
 from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
-from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHeadNormal, ViTSelfAttention1DV2
+from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
 
 __all__ = [
     'Linear1D_Col', 'Linear1D_Row',
     'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
-    'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHeadNormal', 'ViTSelfAttention1DV2'
+    'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
 ]
diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py
new file mode 100644
index 000000000000..aee28926af26
--- /dev/null
+++ b/colossalai/nn/layer/parallel_1d/_operation.py
@@ -0,0 +1,34 @@
+import torch
+
+try:
+    import fused_mix_prec_layer_norm_cuda
+except ImportError:
+    fused_mix_prec_layer_norm_cuda = None
+
+
+class FusedLayerNormAffineFunction1D(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, input, weight, bias, normalized_shape, eps):
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        bias_ = bias.contiguous()
+        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
+            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
+        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
+        grad_input = grad_weight = grad_bias = None
+        grad_input, grad_weight, grad_bias \
+            = fused_mix_prec_layer_norm_cuda.backward_affine(
+                grad_output.contiguous(), mean, invvar,
+                input_, ctx.normalized_shape,
+                weight_, bias_, ctx.eps)
+
+        return grad_input, grad_weight, grad_bias, None, None
\ No newline at end of file
diff --git a/colossalai/nn/layer/parallel_1d/_vit.py b/colossalai/nn/layer/parallel_1d/_vit.py
index 67a6ca2c8157..e45a0e3dfa0d 100644
--- a/colossalai/nn/layer/parallel_1d/_vit.py
+++ b/colossalai/nn/layer/parallel_1d/_vit.py
@@ -6,6 +6,7 @@
 import torch
 from torch import nn as nn, Tensor, distributed as dist
+from torch.nn.init import _calculate_fan_in_and_fan_out
 from colossalai.context import seed, ParallelMode
 from colossalai.core import global_context as gpc
@@ -15,8 +16,8 @@
 from colossalai.utils import checkpoint
 from colossalai.utils import get_current_device
 from .layers import Linear1D_Col, Linear1D_Row
-from .._common_utils import set_tensor_parallel_attribute
 from ..base_layer import ParallelLayer
+from ..fused_bias_gelu import bias_gelu_impl
 
 
 @LAYERS.register_module
@@ -44,7 +45,8 @@ def __init__(self,
                  dropout_prob: float = 0.,
                  dtype=None,
                  checkpoint: bool = False,
-                 skip_bias_add: bool = False
+                 skip_bias_add: bool = False,
+                 weight_init='torch'
                  ):
         super().__init__()
@@ -52,39 +54,50 @@ def __init__(self,
         self.mlp_ratio = mlp_ratio
         self.checkpoint = checkpoint
         self.skip_bias_add = skip_bias_add
-        self.bias = not skip_bias_add
+        assert weight_init in ('torch', 'jax')
+
+        if act_func == 'fused_gelu':
+            self.act = bias_gelu_impl
+            skip_dense_1_add_bias = True
+        else:
+            self.act = ACT2FN[act_func]
+            skip_dense_1_add_bias = False
 
         # Project to mlp_ratio * h.
         self.dense_1 = Linear1D_Col(
             self.in_features,
             int(self.mlp_ratio * self.in_features),
-            bias=not skip_bias_add,
             dtype=dtype,
-            gather_output = False,
+            gather_output=False,
+            skip_bias_add=skip_dense_1_add_bias,
+            init_weight=weight_init,
+            init_bias=weight_init
         )
-        self.act = ACT2FN[act_func]
 
         # Project back to h.
         self.dense_2 = Linear1D_Row(
             int(self.mlp_ratio * self.in_features),
             self.in_features,
-            bias=not skip_bias_add,
             dtype=dtype,
-            parallel_input = True,
+            parallel_input=True,
+            init_weight=weight_init, init_bias=weight_init
         )
 
         self.dropout = nn.Dropout(dropout_prob)
 
     def _forward(self, hidden_states: Tensor) -> Tensor:
-        intermediate_output = self.dense_1(hidden_states)
-        intermediate_output = self.act(intermediate_output)
+        if self.act == bias_gelu_impl:
+            intermediate_output, bias = self.dense_1(hidden_states)
+            intermediate_output = self.act(intermediate_output, bias)
+        else:
+            intermediate_output = self.dense_1(hidden_states)
+            intermediate_output = self.act(intermediate_output)
 
         with seed(ParallelMode.TENSOR):
             intermediate_output = self.dropout(intermediate_output)
         output = self.dense_2(intermediate_output)
-
-        with seed(ParallelMode.TENSOR):
-            output = self.dropout(output)
+        output = self.dropout(output)
         return output
 
     def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
@@ -121,7 +134,8 @@ def __init__(self,
                  attention_dropout_prob: float,
                  hidden_dropout_prob: float,
                  dtype=None,
-                 checkpoint: bool = False
+                 checkpoint: bool = False,
+                 weight_init='torch'
                  ):
         super().__init__()
@@ -131,11 +145,18 @@ def __init__(self,
         self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
         self.checkpoint = checkpoint
+        assert weight_init in ('torch', 'jax')
+        if weight_init == 'jax':
+            init_bias = 'zero'
+        else:
+            init_bias = weight_init
 
         self.query_key_value = Linear1D_Col(
             hidden_size,
             3 * hidden_size,
             dtype=dtype,
+            init_weight=weight_init,
+            init_bias=init_bias
         )
         self.attention_dropout = nn.Dropout(attention_dropout_prob)
         self.dense = Linear1D_Row(
@@ -143,6 +164,7 @@
             hidden_size,
             hidden_size,
             dtype=dtype,
             parallel_input=True,
+            init_weight=weight_init, init_bias=init_bias
         )
         self.dropout = nn.Dropout(hidden_dropout_prob)
         self.softmax = nn.Softmax(dim=-1)
@@ -172,8 +194,8 @@ def _forward(self, hidden_states: Tensor) -> Tensor:
             :-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(new_context_layer_shape)
         output = self.dense(context_layer)
-        with seed(ParallelMode.TENSOR):
-            output = self.dropout(output)
+        output = self.dropout(output)
+
         return output
 
     def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
@@ -185,82 +207,6 @@ def forward(self, hidden_states: Tensor) -> Tensor:
         else:
             return self._forward(hidden_states)
 
-@LAYERS.register_module
-class ViTSelfAttention1DV2(ParallelLayer):
-    """Self-attention layer for 1D parallel Vision Transformer
-
-    :param hidden_size: hidden size
-    :type hidden_size: int
-    :param num_attention_heads: number of attention heads
-    :type num_attention_heads: int
-    :param attention_dropout_prob: dropout probability for attention layers
-    :type attention_dropout_prob: float
-    :param hidden_dropout_prob: dropout probability for hidden layers
-    :type hidden_dropout_prob: float
-    :param dtype: dtype of parameters, defaults to None
-    :type dtype: torch.dtype, optional
-    :param checkpoint: whether to checkpoint the layer, defaults to False
-    :type checkpoint: bool, optional
-    """
-
-    def __init__(self,
-                 hidden_size: int,
-                 num_attention_heads: int,
-                 attention_dropout_prob: float,
-                 hidden_dropout_prob: float,
-                 dtype=None,
-                 checkpoint: bool = False
-                 ):
-        super().__init__()
-
-        self.hidden_size = hidden_size
-        self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
-        self.attention_head_size = divide(hidden_size, num_attention_heads)
-        self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
-        self.checkpoint = checkpoint
-
-        self.query_key_value = Linear1D_Col(
-            hidden_size,
-            3 * hidden_size,
-            dtype=dtype,
-        )
-        self.attention_dropout = nn.Dropout(attention_dropout_prob)
-        self.dense = Linear1D_Row(
-            hidden_size,
-            hidden_size,
-            dtype=dtype,
-            parallel_input=True,
-        )
-        self.dropout = nn.Dropout(hidden_dropout_prob)
-        self.softmax = nn.Softmax(dim=-1)
-        self.scale = self.num_attention_heads ** -0.5
-
-    def _forward(self, x: Tensor) -> Tensor:
-        B, N, C = x.shape
-        qkv = self.query_key_value(x).reshape(B, N, 3, self.num_attention_heads, C //
-                                              (self.num_attention_heads * gpc.tensor_parallel_size)).permute(2, 0, 3, 1, 4)
-        # make torchscript happy (cannot use tensor as tuple)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        with seed(ParallelMode.TENSOR):
-            attn = self.attention_dropout(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C // gpc.tensor_parallel_size)
-        x = self.dense(x)
-        with seed(ParallelMode.TENSOR):
-            x = self.dropout(x)
-        return x
-
-    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
-        return checkpoint(self._forward, hidden_states)
-
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        if self.checkpoint:
-            return self._checkpoint_forward(hidden_states)
-        else:
-            return self._forward(hidden_states)
 
 @LAYERS.register_module
 class ViTHead1D(ParallelLayer):
@@ -278,14 +224,25 @@ def __init__(self,
                  hidden_size,
                  num_classes,
                  dtype=None,
+                 weight_init='torch'
                  ):
         super().__init__()
+        assert weight_init in ('torch', 'jax')
+        if weight_init == 'jax':
+            init_weight = 'zero'
+            init_bias = 'zero'
+        else:
+            init_weight = weight_init
+            init_bias = weight_init
+
         self.linear = Linear1D_Col(
             hidden_size,
             num_classes,
             dtype=dtype,
-            gather_output = True,
+            gather_output=True,
+            init_weight=init_weight,
+            init_bias=init_bias
         )
 
     def forward(self, x: Tensor) -> Tensor:
@@ -294,7 +251,7 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 @LAYERS.register_module
-class ViTHeadNormal(ParallelLayer):
+class ViTHead(ParallelLayer):
     """Output layer for 1D parallel Vision Transformer
 
     :param hidden_size: hidden size
@@ -311,7 +268,6 @@ def __init__(self,
                  dtype=None,
                  ):
         super().__init__()
-
         self.linear = nn.Linear(
             hidden_size,
             num_classes,
@@ -354,7 +310,8 @@ def __init__(self,
                  patch_size,
                  embed_dim,
                  in_chans=3,
-                 flatten=True):
+                 flatten=True,
+                 weight_init='torch'):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -367,20 +324,20 @@ def __init__(self,
         self.flatten = flatten
         self.embed_dim = embed_dim
 
-        with seed(ParallelMode.TENSOR):
-            self.proj = nn.Conv2d(in_chans,
-                                  self.embed_dim,
-                                  kernel_size=patch_size,
-                                  stride=patch_size
-                                  )
+        self.proj = nn.Conv2d(in_chans,
+                              self.embed_dim,
+                              kernel_size=patch_size,
+                              stride=patch_size
+                              )
+
+        if weight_init == 'jax':
+            fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
+            std = math.sqrt(1.0 / fan_in)
+            nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+            nn.init.zeros_(self.proj.bias)
 
         # sync
         self._broadcast_conv_params()
-        self._set_tensor_parallel_attribute()
-
-    def _set_tensor_parallel_attribute(self):
-        set_tensor_parallel_attribute(self.proj.weight)
-        set_tensor_parallel_attribute(self.proj.bias)
 
     def _broadcast_conv_params(self) -> None:
         self.to(get_current_device())
@@ -435,29 +392,20 @@ def __init__(self,
         self.cls_token = nn.Parameter(torch.zeros(
             1, 1, self.embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(
+        self.pos_embed = nn.Parameter(torch.empty(
             1, self.num_patches + 1, self.embed_dim))
+        nn.init.trunc_normal_(self.pos_embed, std=.02)
 
         # move to cuda before broadcast
         self.to(get_current_device())
-
-        # sync param in both forward and backward
-        _cls_token = self.cls_token.view(-1)
-        _pos_embed = self.pos_embed.view(-1)
-        self._param = torch.cat([_cls_token, _pos_embed], dim=0)
-
+        dist.broadcast(self.pos_embed,
+                       src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
+                       group=gpc.get_group(ParallelMode.TENSOR))
 
         self.pos_drop = nn.Dropout(p=drop_rate)
-        self._set_tensor_parallel_attribute()
-
-    def _set_tensor_parallel_attribute(self):
-        set_tensor_parallel_attribute(self.cls_token)
-        set_tensor_parallel_attribute(self.pos_embed)
 
     def forward(self, x: Tensor) -> Tensor:
-        # stole cls_tokens impl from Phil Wang, thanks
         cls_token = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_token, x), dim=1)
-        with seed(ParallelMode.TENSOR):
-            x = self.pos_drop(x + self.pos_embed)
+        x = torch.cat((cls_token, x), dim=1)
+        x = self.pos_drop(x + self.pos_embed)
         return x.contiguous()
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index b67d0a2c1f42..6158da07a634 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -4,6 +4,7 @@
 import math
 import numbers
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
@@ -16,14 +17,12 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import LAYERS
 from colossalai.utils import get_current_device
+from ._operation import FusedLayerNormAffineFunction1D
 from .._common_utils import divide, set_tensor_parallel_attribute
 from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
     split_forward_gather_backward
 from ..base_layer import ParallelLayer
 
-global fused_mix_prec_layer_norm_cuda
-fused_mix_prec_layer_norm_cuda = None
-
 
 @LAYERS.register_module
 class Linear1D_Col(ParallelLayer):
@@ -51,14 +50,21 @@ def __init__(self,
                  output_size: int,
                  bias: bool = True,
                  dtype: torch.dtype = None,
-                 gather_output: bool = False):
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 init_weight='torch',
+                 init_bias='torch'
+                 ):
         super().__init__()
 
         # Keep input parameters
         self.in_features = in_features
         self.out_features = output_size
         self.gather_output = gather_output
-        self.skip_bias_add = not bias
+        self.skip_bias_add = skip_bias_add
+
+        if skip_bias_add and not bias:
+            raise ValueError('cannot skip bias addition if bias is None')
 
         self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
@@ -78,25 +84,39 @@ def __init__(self,
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-        self.reset_parameters()
+        with seed(ParallelMode.TENSOR):
+            self.reset_parameters(init_weight, init_bias)
         self._set_tensor_parallel_attributes()
 
-    def reset_parameters(self) -> None:
-        fan_in = self.in_features
-        a = math.sqrt(5)
-        nonlinearity = 'leaky_relu'
+    def reset_parameters(self, init_weight, init_bias) -> None:
+        assert init_weight in ('torch', 'jax', 'zero')
+        assert init_bias in ('torch', 'jax', 'zero')
+        # setting
+        fan_in, fan_out = self.in_features, self.out_features
 
         # init weight
-        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
-        bound = math.sqrt(3.0) * std
-        with seed(ParallelMode.TENSOR):
-            nn.init.uniform_(self.weight, -bound, bound)
+        if init_weight == 'torch':
+            a = math.sqrt(5)
+            nonlinearity = 'leaky_relu'
+            std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+            bound = math.sqrt(3.0) * std
+            init.uniform_(self.weight, -bound, bound)
+        elif init_weight == 'jax':
+            std = math.sqrt(2.0 / float(fan_in + fan_out))
+            a = math.sqrt(3.0) * std
+            init.uniform_(self.weight, -a, a)
+        elif init_weight == 'zero':
+            init.zeros_(self.weight)
 
         # init bias
-        if not self.skip_bias_add:
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            with seed(ParallelMode.TENSOR):
-                nn.init.uniform_(self.bias, -bound, bound)
+        if self.bias is not None:
+            if init_bias == 'torch':
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                init.uniform_(self.bias, -bound, bound)
+            elif init_bias == 'jax':
+                init.normal_(self.bias, std=1e-6)
+            elif init_bias == 'zero':
+                init.zeros_(self.bias)
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute(self.weight)
@@ -143,7 +163,10 @@ def __init__(self,
                  out_features: int,
                  bias: bool = True,
                  dtype: torch.dtype = None,
-                 parallel_input: bool = False
+                 parallel_input: bool = False,
+                 skip_bias_add: bool = False,
+                 init_weight='torch',
+                 init_bias='torch'
                  ):
         super().__init__()
 
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features
         self.parallel_input = parallel_input
-        self.skip_bias_add = not bias
+        self.skip_bias_add = skip_bias_add
+
+        if skip_bias_add and not bias:
+            raise ValueError('cannot skip bias addition if bias is None')
 
         # Divide the weight matrix along the last dimension.
         self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
@@ -175,31 +201,46 @@ def __init__(self,
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-        self.reset_parameters()
+        with seed(ParallelMode.TENSOR):
+            self.reset_parameters(init_weight, init_bias)
         self._set_tensor_parallel_attributes()
 
-    def reset_parameters(self) -> None:
+    def reset_parameters(self, init_weight, init_bias) -> None:
+        assert init_weight in ('torch', 'jax', 'zero')
+        assert init_bias in ('torch', 'jax', 'zero')
         # setting
-        fan_in = self.in_features
-        a = math.sqrt(5)
-        nonlinearity = 'leaky_relu'
+        fan_in, fan_out = self.in_features, self.out_features
 
         # init weight
-        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
-        bound = math.sqrt(3.0) * std
-        with seed(ParallelMode.TENSOR):
-            nn.init.uniform_(self.weight, -bound, bound)
+        if init_weight == 'torch':
+            a = math.sqrt(5)
+            nonlinearity = 'leaky_relu'
+            std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+            bound = math.sqrt(3.0) * std
+            init.uniform_(self.weight, -bound, bound)
+        elif init_weight == 'jax':
+            std = math.sqrt(2.0 / float(fan_in + fan_out))
+            a = math.sqrt(3.0) * std
+            init.uniform_(self.weight, -a, a)
+        elif init_weight == 'zero':
+            init.zeros_(self.weight)
 
         # init bias
-        if not self.skip_bias_add:
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            with seed(ParallelMode.TENSOR):
-                nn.init.uniform_(self.bias, -bound, bound)
+        if self.bias is not None:
+            if init_bias == 'torch':
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                init.uniform_(self.bias, -bound, bound)
+            elif init_bias == 'jax':
+                init.normal_(self.bias, std=1e-6)
+            elif init_bias == 'zero':
+                init.zeros_(self.bias)
+            dist.broadcast(self.bias,
+                           src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
+                           group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute(self.weight)
-        if self.bias is not None:
-            set_tensor_parallel_attribute(self.bias)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
@@ -214,41 +255,10 @@ def forward(self, input_: Tensor) -> Tensor:
 
         if not self.skip_bias_add:
             output = output + self.bias
-        return output
-
-
-
-
-@LAYERS.register_module
-class FusedLayerNormAffineFunction1D(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, weight, bias, normalized_shape, eps):
-
-        ctx.normalized_shape = normalized_shape
-        ctx.eps = eps
-        input_ = input.contiguous()
-        weight_ = weight.contiguous()
-        bias_ = bias.contiguous()
-        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
-            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
-        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
-
-        return output
-
-
-    @staticmethod
-    def backward(ctx, grad_output):
-
-        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
-        grad_input = grad_weight = grad_bias = None
-        grad_input, grad_weight, grad_bias \
-            = fused_mix_prec_layer_norm_cuda.backward_affine(
-                grad_output.contiguous(), mean, invvar,
-                input_, ctx.normalized_shape,
-                weight_, bias_, ctx.eps)
-
-        return grad_input, grad_weight, grad_bias, None, None
+            return output
+        else:
+            return output, self.bias
 
 
 @LAYERS.register_module
@@ -257,10 +267,6 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
 
     def __init__(self, normalized_shape, eps=1e-5):
         super(MixedFusedLayerNorm1D, self).__init__()
 
-        global fused_mix_prec_layer_norm_cuda
-        fused_mix_prec_layer_norm_cuda = importlib.import_module(
-            "fused_mix_prec_layer_norm_cuda")
-
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
@@ -271,12 +277,10 @@ def __init__(self, normalized_shape, eps=1e-5):
 
     def reset_parameters(self):
-
         init.ones_(self.weight)
         init.zeros_(self.bias)
 
     def forward(self, input):
-        return FusedLayerNormAffineFunction1D.apply(
-            input, self.weight, self.bias, self.normalized_shape,self.eps)
\ No newline at end of file
+        return FusedLayerNormAffineFunction1D.apply(
+            input, self.weight, self.bias, self.normalized_shape, self.eps)
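The behavioural change in `Linear1D_Col`/`Linear1D_Row` above: with `skip_bias_add=True` the layer now returns an `(output, bias)` pair so the caller can fuse the add, e.g. into `bias_gelu_impl`. A sketch of the two calling conventions (`x` is a placeholder activation of shape `(batch, 768)`):

from colossalai.nn.layer.parallel_1d import Linear1D_Col
from colossalai.nn.layer.fused_bias_gelu import bias_gelu_impl

dense = Linear1D_Col(768, 3072)                      # default: bias is folded into the output
out = dense(x)

dense = Linear1D_Col(768, 3072, skip_bias_add=True)  # bias is handed back to the caller
out, bias = dense(x)
out = bias_gelu_impl(out, bias)                      # the add happens inside the fused kernel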
diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py
index 1cfced732170..c3722b43e19c 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/nn/layer/parallel_2d/_operation.py
@@ -85,25 +85,30 @@ def forward(ctx: Any,
         ctx.save_for_backward(A, B)
 
         A_shape = A.shape
-        A = A.reshape((-1, A_shape[-1]))
+        A = A.reshape((-1, A_shape[-1])).contiguous()
         B_shape = B.shape
-        B = B.reshape((-1, B_shape[-1]))
+        B = B.reshape((-1, B_shape[-1])).contiguous()
         C_shape = (A.shape[0], B.shape[-1])
         C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
 
+        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
+        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
+        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
+        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
+        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
+        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
+        for op in [op_a, op_b]:
+            op.wait()
+
         for i in range(summa_dim):
-            A_temp = A.clone()
-            B_temp = B.clone()
-            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(A_temp, src=src_a,
-                           group=gpc.get_group(row_parallel_mode))
-            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(B_temp, src=src_b,
-                           group=gpc.get_group(col_parallel_mode))
+            # group-local source indices; both reduce to i within a summa_dim-sized group
+            src_a = (i + summa_dim * row_rank) % summa_dim
+            src_b = (i + summa_dim * col_rank) % summa_dim
+            A_temp = A_list[src_a]
+            B_temp = B_list[src_b]
             torch.addmm(C, A_temp, B_temp, out=C)
-
         out = C.reshape(out_shape)
 
         if ctx:
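The SUMMA change above trades `2 * summa_dim` broadcasts per matmul for a single async all-gather per process group; afterwards each loop step is a local list lookup. A standalone sketch of the gather pattern, assuming `my_block` and `group` are supplied by the caller:

import torch
import torch.distributed as dist

def gather_blocks(my_block: torch.Tensor, group) -> list:
    world_size = dist.get_world_size(group)
    # pre-allocate slots for the peers, then splice the local block in at our
    # own index so it is reused rather than copied
    blocks = [torch.empty_like(my_block) for _ in range(world_size - 1)]
    blocks.insert(dist.get_rank(group), my_block)
    work = dist.all_gather(blocks, my_block, group=group, async_op=True)
    work.wait()
    return blocks   # blocks[k] is rank k's tile, ready for the k-th partial matmul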
diff --git a/colossalai/nn/layer/parallel_2d/_vit.py b/colossalai/nn/layer/parallel_2d/_vit.py
index 2e608cea2318..3f8ed2a437aa 100644
--- a/colossalai/nn/layer/parallel_2d/_vit.py
+++ b/colossalai/nn/layer/parallel_2d/_vit.py
@@ -18,6 +18,7 @@
 from .layers import Linear2D
 from .._common_utils import set_tensor_parallel_attribute
 from ..base_layer import ParallelLayer
+from ..fused_bias_gelu import bias_gelu_impl
 
 
 @LAYERS.register_module
@@ -55,15 +56,22 @@ def __init__(self,
         self.checkpoint = checkpoint
         assert weight_init in ('torch', 'jax')
 
+        if act_func == 'fused_gelu':
+            self.act = bias_gelu_impl
+            skip_dense_1_add_bias = True
+        else:
+            self.act = ACT2FN[act_func]
+            skip_dense_1_add_bias = False
+
         # Project to mlp_ratio * h.
         self.dense_1 = Linear2D(
             self.in_features,
             self.mlp_ratio * self.in_features,
             dtype=dtype,
-            init_weight=weight_init, init_bias=weight_init
+            init_weight=weight_init, init_bias=weight_init,
+            skip_bias_add=skip_dense_1_add_bias
         )
-        self.act = ACT2FN[act_func]
 
         # Project back to h.
         self.dense_2 = Linear2D(
@@ -75,8 +83,12 @@ def __init__(self,
         self.dropout = nn.Dropout(dropout_prob)
 
     def _forward(self, hidden_states: Tensor) -> Tensor:
-        intermediate_output = self.dense_1(hidden_states)
-        intermediate_output = self.act(intermediate_output)
+        if self.act == bias_gelu_impl:
+            intermediate_output, bias = self.dense_1(hidden_states)
+            intermediate_output = self.act(intermediate_output, bias)
+        else:
+            intermediate_output = self.dense_1(hidden_states)
+            intermediate_output = self.act(intermediate_output)
 
         with seed(ParallelMode.TENSOR):
             intermediate_output = self.dropout(intermediate_output)
@@ -270,19 +282,21 @@ def __init__(self,
         self.flatten = flatten
         self.embed_dim = embed_dim // (self.summa_dim ** 2)
 
-        self.proj = nn.Conv2d(in_chans,
-                              self.embed_dim,
-                              kernel_size=patch_size,
-                              stride=patch_size,
-                              device=get_current_device()
-                              )
+        with seed(ParallelMode.TENSOR):
+            self.proj = nn.Conv2d(in_chans,
+                                  self.embed_dim,
+                                  kernel_size=patch_size,
+                                  stride=patch_size,
+                                  device=get_current_device()
+                                  )
 
         self._set_tensor_parallel_attribute()
 
         if weight_init == 'jax':
-            fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
-            std = math.sqrt(1.0 / fan_in)
-            nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
-            nn.init.zeros_(self.proj.bias)
+            with seed(ParallelMode.TENSOR):
+                fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
+                std = math.sqrt(1.0 / fan_in)
+                nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
+                nn.init.zeros_(self.proj.bias)
 
     def _set_tensor_parallel_attribute(self):
         set_tensor_parallel_attribute(self.proj.weight)
@@ -356,7 +370,8 @@ def __init__(self,
         self.pos_embed = nn.Parameter(torch.empty(
             (1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
             device=get_current_device()))
-        nn.init.trunc_normal_(self.pos_embed, std=.02)
+        with seed(ParallelMode.TENSOR):
+            nn.init.trunc_normal_(self.pos_embed, std=.02)
 
         self.pos_drop = nn.Dropout(p=drop_rate)
         self._set_tensor_parallel_attribute()
diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py
index 4af061f98f8c..396d55e239aa 100644
--- a/colossalai/nn/layer/parallel_2d/layers.py
+++ b/colossalai/nn/layer/parallel_2d/layers.py
@@ -73,7 +73,8 @@ def __init__(self,
             self.register_parameter('bias', None)
 
         # initialize parameters
-        self.reset_parameters(init_weight, init_bias)
+        with seed(ParallelMode.TENSOR):
+            self.reset_parameters(init_weight, init_bias)
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self):
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index e72f84bad417..6015c55c6dea 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1,7 +1,6 @@
 from .base_loss import BaseLoss
-from .cross_entropy_1d import CrossEntropyLoss1D
 from .cross_entropy_2d import CrossEntropyLoss2D
 from .cross_entropy_2p5d import CrossEntropyLoss2p5D
 from .cross_entropy_3d import CrossEntropyLoss3D
 
-__all__ = ['CrossEntropyLoss1D', 'CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D']
+__all__ = ['CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D']
diff --git a/colossalai/nn/loss/cross_entropy_1d.py b/colossalai/nn/loss/cross_entropy_1d.py
deleted file mode 100644
index 1b8349bbfaa3..000000000000
--- a/colossalai/nn/loss/cross_entropy_1d.py
+++ /dev/null
@@ -1,159 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from colossalai.nn.layer.parallel_1d.layers import Linear1D_Col
-from colossalai.utils.cuda import get_current_device
-import torch
-import torch.nn.functional as F
-from torch.nn.modules.loss import _Loss
-from colossalai.registry import LOSSES
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size, vocab_range_from_global_vocab_size
-
-
-class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, vocab_parallel_logits, target):
-        # Maximum value along vocab dimension across all GPUs.
-        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
-        # print("logits_max shape:", logits_max.size())
-        torch.distributed.all_reduce(logits_max,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
-        # print("logits_max shape after all reduce:", logits_max.size())
-        # Subtract the maximum value.
-        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
-        # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
-        # print("vocab_parallel_logits shape:", vocab_parallel_logits.size())
-
-        # Get the partition's vocab indecies
-        # partition_vocab_size = vocab_parallel_logits.size()[-1]
-        partition_vocab_size = vocab_parallel_logits.size(-1)
-        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-        world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-        vocab_start_index, vocab_end_index = vocab_range_from_global_vocab_size(
-            partition_vocab_size, rank, world_size)
-        # print("partition_vocab_size, rank, world_size, vocab_start_index, vocab_end_index: ", partition_vocab_size, rank, world_size, vocab_start_index, vocab_end_index)
-
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
-        masked_target = target.clone() - vocab_start_index
-        masked_target[target_mask] = 0
-
-        # Get predicted-logits = logits[target].
-        # For Simplicity, we convert logits to a 2-D tensor with size
-        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
-        # logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size).contiguous()
-        # masked_target_1d = masked_target.view(-1).contiguous()
-        # arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
-        #                          device=logits_2d.device)
-        arange_1d = torch.arange(start=0, end=vocab_parallel_logits.size()[0])
-        # predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
-        predicted_logits = vocab_parallel_logits[arange_1d, masked_target]
-        # predicted_logits_1d = predicted_logits_1d.clone().contiguous()
-        # predicted_logits = predicted_logits_1d.view_as(target)
-        predicted_logits[target_mask] = 0.0
-        # All reduce is needed to get the chunks from other GPUs.
-        torch.distributed.all_reduce(predicted_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
-
-        # Sum of exponential of logits along vocab dimension across all GPUs.
-        # exp_logits = vocab_parallel_logits
-        # torch.exp(vocab_parallel_logits, out=exp_logits)
-        exp_logits = torch.exp(vocab_parallel_logits)
-        sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
-
-        # Loss = log(sum(exp(logits))) - predicted-logit.
-        loss = torch.log(sum_exp_logits) - predicted_logits
-
-        # Store softmax, target-mask and masked-target for backward pass.
-        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
-        ctx.save_for_backward(exp_logits, target_mask, masked_target)
-
-        return loss
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Retreive tensors from the forward path.
-        softmax, target_mask, masked_target_1d = ctx.saved_tensors
-
-        # All the inputs have softmax as thier gradient.
-        grad_input = softmax
-        # For simplicity, work with the 2D gradient.
-        partition_vocab_size = softmax.size()[-1]
-        grad_2d = grad_input.view(-1, partition_vocab_size)
-
-        # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
-                                 device=get_current_device())
-        grad_2d[arange_1d, masked_target_1d] -= (
-            1.0 - target_mask.view(-1).float())
-
-        # Finally elementwise multiplication with the output gradients.
-        grad_input.mul_(grad_output.unsqueeze(dim=-1))
-
-        return grad_input, None
-
-
-@LOSSES.register_module
-class LmLoss1D(_Loss):
-
-    def forward(self, lm_logits, lm_labels, loss_mask):
-        lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
-        lm_loss = torch.sum(
-            lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
-        return lm_loss
-
-
-@LOSSES.register_module
-class SopLoss1D(_Loss):
-
-    def forward(self, sop_logits, sentence_order):
-        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
-                                   sentence_order.view(-1),
-                                   ignore_index=-1)
-        return sop_loss
-
-
-@LOSSES.register_module
-class BERTDualHeadLoss(_Loss):
-
-    def __init__(self):
-        self.lm_loss = LmLoss1D()
-        self.sop_loss = SopLoss1D()
-
-    def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
-        lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
-        sop_loss = self.sop_loss(sop_logits, sentence_order)
-        return lm_loss + sop_loss
-
-
-@LOSSES.register_module
-class CrossEntropyLoss1D(_Loss):
-    """Cross entropy loss for 1D parallelism
-
-    :param reduction: whether to average the loss, defaults to True
-    :type reduction: bool, optional
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.dim = gpc.tensor_parallel_size
-
-    def forward(self, logits, targets):
-        # loss = _VocabParallelCrossEntropy_1D.apply(
-        #     logits, targets,
-        # )
-        # print("loss :", loss.size())
-        # print("loss contiguous or not: ", loss.is_contiguous())
-        # return loss
-
-        # dist_loss = loss.mean()
-
-        # test tp=1
-        loss = torch.nn.CrossEntropyLoss()
-        dist_loss = loss(logits, targets)
-        return dist_loss
diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py
index a516592dd242..bc0205344298 100644
--- a/colossalai/utils/timer.py
+++ b/colossalai/utils/timer.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-
 import time
-
 from .cuda import synchronize
 
 
@@ -10,7 +8,6 @@ class Timer:
     '''
     A timer object which helps to log the execution times, and provides different tools to assess the times.
    '''
-
     def __init__(self):
         self._started = False
         self._start_time = time.time()
@@ -31,7 +28,6 @@ def start(self):
 
     def stop(self, keep_in_history: bool = False):
         '''Stop the timer and record the start-stop time interval.
-
         :param keep_in_history: whether does it record into history each start-stop interval, defaults to False
         :type keep_in_history: bool, optional
         :return: start-stop interval
@@ -48,7 +44,6 @@ def stop(self, keep_in_history: bool = False):
 
     def get_history_mean(self):
         '''mean of all history start-stop time intervals.
-
         :return: mean of time intervals
         :rtype: int
         '''
@@ -56,7 +51,6 @@ def get_history_mean(self):
 
     def get_history_sum(self):
         '''add up all the start-stop time intervals.
-
         :return: sum of time intervals
         :rtype: int
         '''
@@ -64,7 +58,6 @@ def get_history_sum(self):
 
     def get_elapsed_time(self):
         '''return the last start-stop time interval. *use it only when timer is not in progress*
-
         :return: the last time interval
         :rtype: int
         '''
@@ -89,7 +82,6 @@ def __init__(self, on: bool = True):
 
     def start(self, name: str):
         '''Start namely one of the timers
-
         :param name: timer's key
         :type name: str
         '''
@@ -100,7 +92,6 @@ def start(self, name: str):
 
     def stop(self, name: str, keep_in_history: bool):
         '''Stop namely one of the timers.
-
         :param name: timer's key
         :param keep_in_history: whether does it record into history each start-stop interval
         :type keep_in_history: bool
@@ -112,7 +103,6 @@ def stop(self, name: str, keep_in_history: bool):
 
     def get_timer(self, name):
         '''Get timer by its name (from multitimer)
-
         :param name: timer's key
         :return: timer with the name you give correctly
         :rtype: Timer
@@ -121,7 +111,6 @@ def get_timer(self, name):
 
     def reset(self, name=None):
         '''Reset timers.
-
         :param name: if name is designated, the named timer will be reset and others will not, defaults to None
         '''
         if self._on:
@@ -132,7 +121,6 @@ def reset(self, name=None):
                 timer.reset()
 
     def is_on(self):
-
         return self._on
 
     def set_status(self, mode: bool):
@@ -140,4 +128,4 @@ def set_status(self, mode: bool):
 
     def __iter__(self):
         for name, timer in self._timers.items():
-            yield name, timer
+            yield name, timer
\ No newline at end of file
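Putting the new options together, constructing a 1D-parallel MLP that exercises the fused path would look roughly like this (argument names inferred from the `ViTMLP1D` constructor above; values are illustrative only):

from colossalai.nn.layer.parallel_1d import ViTMLP1D

mlp = ViTMLP1D(
    in_features=768,
    mlp_ratio=4,
    act_func='fused_gelu',   # selects bias_gelu_impl and sets skip_bias_add on dense_1
    dropout_prob=0.1,
    weight_init='jax',       # xavier-uniform weights, biases drawn as N(0, 1e-6)
)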