From 1251d1b57b42d781646108529385b804684f289f Mon Sep 17 00:00:00 2001 From: WANG-CR Date: Mon, 29 Nov 2021 13:30:49 +0800 Subject: [PATCH] Integrate 1d tensor parallel in Colossal-AI --- colossalai/nn/layer/parallel_1d/__init__.py | 8 +- .../nn/layer/parallel_1d/_transformer.py | 220 +++++++++ colossalai/nn/layer/parallel_1d/_utils.py | 3 + colossalai/nn/layer/parallel_1d/_vit.py | 463 ++++++++++++++++++ colossalai/nn/layer/parallel_1d/layers.py | 144 +++++- colossalai/nn/loss/__init__.py | 3 +- colossalai/nn/loss/cross_entropy_1d.py | 75 ++- colossalai/nn/loss/cross_entropy_2d.py | 4 +- colossalai/trainer/hooks/_metric_hook.py | 32 +- colossalai/trainer/metric.py | 33 +- model_zoo/vit/parallel_1d/vit.py | 208 ++++++++ tests/test_layers/test_1d/common.py | 3 +- tests/test_layers/test_1d/test_1d.py | 16 +- tests/test_layers/test_1d/test_layer.py | 345 ++++++++++--- .../test_vision_transformer/configs/vit_1d.py | 137 ++++++ .../test_vit_1d/test_vit_1d.py | 104 ++++ 16 files changed, 1680 insertions(+), 118 deletions(-) create mode 100644 colossalai/nn/layer/parallel_1d/_transformer.py create mode 100644 colossalai/nn/layer/parallel_1d/_vit.py create mode 100644 model_zoo/vit/parallel_1d/vit.py create mode 100644 tests/test_models/test_vision_transformer/configs/vit_1d.py create mode 100644 tests/test_models/test_vision_transformer/test_vit_1d/test_vit_1d.py diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py index 9e7df549fdb0..72e6e5a7c82c 100644 --- a/colossalai/nn/layer/parallel_1d/__init__.py +++ b/colossalai/nn/layer/parallel_1d/__init__.py @@ -1,5 +1,11 @@ from .layers import Linear1D_Col, Linear1D_Row +from .layers import MixedFusedLayerNorm1D as LayerNorm1D +from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D +from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHeadNormal, ViTSelfAttention1DV2 + + __all__ = [ - 'Linear1D_Col', 'Linear1D_Row', + 'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D', + 'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHeadNormal', 'ViTSelfAttention1DV2' ] diff --git a/colossalai/nn/layer/parallel_1d/_transformer.py b/colossalai/nn/layer/parallel_1d/_transformer.py new file mode 100644 index 000000000000..90a8d740eea5 --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/_transformer.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +import math +from torch import Tensor +from torch.nn.parameter import Parameter +from typing import Tuple + +from colossalai.context import seed, ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import LAYERS +from colossalai.utils import get_current_device +from .._common_utils import divide, ACT2FN +from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \ + split_forward_gather_backward +from ..base_layer import ParallelLayer +from .layers import Linear1D_Col, Linear1D_Row +from .layers import MixedFusedLayerNorm1D as LayerNorm1D + +@LAYERS.register_module +class TransformerMLP1D(ParallelLayer): + """MLP. + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. 
+ """ + + def __init__(self, + in_features: int, + mlp_ratio: int = 4.0, + act_func: str = 'gelu', + dropout_prob: float = 0., + dtype=None, + skip_bias_add: bool = False + ): + super(TransformerMLP1D, self).__init__() + self.in_features = in_features + self.mlp_ratio = mlp_ratio + self.skip_bias_add = skip_bias_add + # Project to h * mlp_ratio. + self.dense_1 = Linear1D_Col( + self.in_features, + int(self.mlp_ratio * self.in_features), + bias=not skip_bias_add, + dtype=dtype, + gather_output = False, + ) + + assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \ + f'activation function can only be {list(ACT2FN.keys())}' + self.activation_func = ACT2FN[act_func] + + # Project back to h. + self.dense_2 = Linear1D_Row( + int(self.mlp_ratio * self.in_features), + self.in_features, + bias=not skip_bias_add, + dtype=dtype, + parallel_input = True, + ) + self.dropout = nn.Dropout(dropout_prob) + # self.layernorm = LayerNorm1D(in_features, dtype=dtype) + self.layernorm = nn.LayerNorm(in_features, dtype=dtype) + def forward(self, x): + if self.skip_bias_add: + intermediate_output, _ = self.dense_1(x) + else: + intermediate_output = self.dense_1(x) + + intermediate_output = self.activation_func(intermediate_output) + + if self.skip_bias_add: + output, _ = self.dense_2(intermediate_output) + else: + output = self.dense_2(intermediate_output) + + with seed(ParallelMode.TENSOR): + output = self.dropout(output) + output = self.layernorm(x + output) + return output + +@LAYERS.register_module +class TransformerSelfAttention1D(ParallelLayer): + """Self attention layer for 1D parallel Transformer + + :param hidden_size: hidden size + :type hidden_size: int + :param num_attention_heads: number of attention heads + :type num_attention_heads: int + :param attention_dropout_prob: dropout probability for attention layer + :type attention_dropout_prob: float + :param hidden_dropout_prob: dropout probability for hidden layer + :type hidden_dropout_prob: float + :param dtype: dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + """ + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + ): + + super().__init__() + + self.hidden_size = hidden_size + + self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size) + self.attention_head_size = divide(hidden_size, num_attention_heads) + self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size) + + self.query_key_value = Linear1D_Col( + hidden_size, + 3 * hidden_size, + dtype=dtype, + ) + self.attention_dropout = nn.Dropout(attention_dropout_prob) + self.dense = Linear1D_Row( + hidden_size, + hidden_size, + dtype=dtype, + parallel_input=True, + ) + self.dropout = nn.Dropout(hidden_dropout_prob) + + # need to re-enable torch grad to enable fused optimization. 
+ # self.layernorm = LayerNorm1D( + # hidden_size, + # dtype=dtype) + self.layernorm = nn.LayerNorm( + hidden_size, + dtype=dtype) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + query_key_value = self.query_key_value(hidden_states) + new_qkv_shape = query_key_value.shape[:-1] + \ + (self.num_attention_heads, 3 * self.attention_head_size) + query_key_value = query_key_value.view(new_qkv_shape) + query_key_value = query_key_value.permute((0, 2, 1, 3)) + query_layer, key_layer, value_layer = torch.chunk( + query_key_value, 3, dim=-1) + + attention_scores = torch.matmul( + query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / \ + math.sqrt(self.attention_head_size) + attention_scores = attention_scores + attention_mask + attention_probs = nn.Softmax(dim=-1)(attention_scores) + with seed(ParallelMode.TENSOR): + attention_probs = self.attention_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute((0, 2, 1, 3)).contiguous() + new_context_layer_shape = context_layer.size()[ + :-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.dense(context_layer) + with seed(ParallelMode.TENSOR): + output = self.dropout(output) + attention_output = self.layernorm(hidden_states + output) + + return attention_output + +@LAYERS.register_module +class TransformerLayer1D(ParallelLayer): + """Transformer layer which contains a self-attention layer and a MLP layer + + :param hidden_size: hidden size + :type hidden_size: int + :param num_attention_heads: number of attention heads + :type num_attention_heads: int + :param act_func: activation function, defaults to 'gelu' + :type act_func: str, optional + :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0 + :type mlp_ratio: float, optional + :param attention_dropout_prob: dropout probability for attention layer, defaults to 0. + :type attention_dropout_prob: float, optional + :param hidden_dropout_prob: dropout probability for attention layer, defaults to 0. 
+ :type hidden_dropout_prob: float, optional + :param dtype: dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + """ + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + act_func: str = 'gelu', + mlp_ratio: float = 4.0, + attention_dropout_prob: float = 0., + hidden_dropout_prob: float = 0., + dtype=None, + ): + super().__init__() + + self.attention = TransformerSelfAttention1D( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + ) + self.mlp = TransformerMLP1D( + in_features=hidden_size, + dropout_prob=hidden_dropout_prob, + act_func=act_func, + mlp_ratio=mlp_ratio, + dtype=dtype, + ) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + attention_output = self.attention(hidden_states, attention_mask) + output = self.mlp(attention_output) + return output diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py index 00d221e786f6..3e1afa1865f0 100644 --- a/colossalai/nn/layer/parallel_1d/_utils.py +++ b/colossalai/nn/layer/parallel_1d/_utils.py @@ -13,3 +13,6 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank): def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): per_partition_vocab_size = divide(global_vocab_size, world_size) return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank) + + + diff --git a/colossalai/nn/layer/parallel_1d/_vit.py b/colossalai/nn/layer/parallel_1d/_vit.py new file mode 100644 index 000000000000..67a6ca2c8157 --- /dev/null +++ b/colossalai/nn/layer/parallel_1d/_vit.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from colossalai import context + +import torch +from torch import nn as nn, Tensor, distributed as dist + +from colossalai.context import seed, ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.layer._common_utils import divide, ACT2FN +from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple +from colossalai.registry import LAYERS +from colossalai.utils import checkpoint +from colossalai.utils import get_current_device +from .layers import Linear1D_Col, Linear1D_Row +from .._common_utils import set_tensor_parallel_attribute +from ..base_layer import ParallelLayer + + +@LAYERS.register_module +class ViTMLP1D(ParallelLayer): + """MLP layer for 1D parallel Vision Transformer + + :param in_features: size of each input sample + :type in_features: int + :param mlp_ratio: hidden size of MLP divided by embedding dim + :type mlp_ratio: int + :param act_func: activation function, defaults to 'gelu' + :type act_func: str, optional + :param dropout_prob: dropout probability, defaults to 0. + :type dropout_prob: float, optional + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param checkpoint: whether to checkpoint the layer, defaults to False + :type checkpoint: bool, optional + """ + + def __init__(self, + in_features: int, + mlp_ratio: int, + act_func: str = 'gelu', + dropout_prob: float = 0., + dtype=None, + checkpoint: bool = False, + skip_bias_add: bool = False + ): + super().__init__() + + self.in_features = in_features + self.mlp_ratio = mlp_ratio + self.checkpoint = checkpoint + self.skip_bias_add = skip_bias_add + self.bias = not skip_bias_add + # Project to mlp_ratio * h. 
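+ # Column-parallel projection: the weight is sharded along the output dimension across the 1D tensor-parallel group,
+ # and gather_output=False keeps the activation sharded so that the row-parallel dense_2 below can consume it
+ # directly and reduce the partial results back to the full hidden size.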
+ self.dense_1 = Linear1D_Col( + self.in_features, + int(self.mlp_ratio * self.in_features), + bias=not skip_bias_add, + dtype=dtype, + gather_output = False, + ) + + self.act = ACT2FN[act_func] + + # Project back to h. + self.dense_2 = Linear1D_Row( + int(self.mlp_ratio * self.in_features), + self.in_features, + bias=not skip_bias_add, + dtype=dtype, + parallel_input = True, + ) + + self.dropout = nn.Dropout(dropout_prob) + + def _forward(self, hidden_states: Tensor) -> Tensor: + intermediate_output = self.dense_1(hidden_states) + intermediate_output = self.act(intermediate_output) + + with seed(ParallelMode.TENSOR): + intermediate_output = self.dropout(intermediate_output) + output = self.dense_2(intermediate_output) + + with seed(ParallelMode.TENSOR): + output = self.dropout(output) + return output + + def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: + return checkpoint(self._forward, hidden_states) + + def forward(self, hidden_states: Tensor) -> Tensor: + if self.checkpoint: + return self._checkpoint_forward(hidden_states) + else: + return self._forward(hidden_states) + + +@LAYERS.register_module +class ViTSelfAttention1D(ParallelLayer): + """Self-attention layer for 1D parallel Vision Transformer + + :param hidden_size: hidden size + :type hidden_size: int + :param num_attention_heads: number of attention heads + :type num_attention_heads: int + :param attention_dropout_prob: dropout probability for attention layers + :type attention_dropout_prob: float + :param hidden_dropout_prob: dropout probability for hidden layers + :type hidden_dropout_prob: float + :param dtype: dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param checkpoint: whether to checkpoint the layer, defaults to False + :type checkpoint: bool, optional + """ + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False + ): + super().__init__() + + self.hidden_size = hidden_size + self.attention_head_size = divide(hidden_size, num_attention_heads) + self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size) + self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size) + + self.checkpoint = checkpoint + + self.query_key_value = Linear1D_Col( + hidden_size, + 3 * hidden_size, + dtype=dtype, + ) + self.attention_dropout = nn.Dropout(attention_dropout_prob) + self.dense = Linear1D_Row( + hidden_size, + hidden_size, + dtype=dtype, + parallel_input=True, + ) + self.dropout = nn.Dropout(hidden_dropout_prob) + self.softmax = nn.Softmax(dim=-1) + + def _forward(self, hidden_states: Tensor) -> Tensor: + query_key_value = self.query_key_value(hidden_states) + new_qkv_shape = query_key_value.shape[:-1] + \ + (self.num_attention_heads_per_partition, 3 * self.attention_head_size) + query_key_value = query_key_value.view(new_qkv_shape) + query_key_value = query_key_value.permute((0, 2, 1, 3)) + query_layer, key_layer, value_layer = torch.chunk( + query_key_value, 3, dim=-1) + + attention_scores = torch.matmul( + query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / \ + math.sqrt(self.attention_head_size) + + attention_probs = self.softmax(attention_scores) + + with seed(ParallelMode.TENSOR): + attention_probs = self.attention_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.transpose(1, 2) + new_context_layer_shape = 
context_layer.size()[ + :-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(new_context_layer_shape) + output = self.dense(context_layer) + with seed(ParallelMode.TENSOR): + output = self.dropout(output) + return output + + def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: + return checkpoint(self._forward, hidden_states) + + def forward(self, hidden_states: Tensor) -> Tensor: + if self.checkpoint: + return self._checkpoint_forward(hidden_states) + else: + return self._forward(hidden_states) + +@LAYERS.register_module +class ViTSelfAttention1DV2(ParallelLayer): + """Self-attention layer for 1D parallel Vision Transformer + + :param hidden_size: hidden size + :type hidden_size: int + :param num_attention_heads: number of attention heads + :type num_attention_heads: int + :param attention_dropout_prob: dropout probability for attention layers + :type attention_dropout_prob: float + :param hidden_dropout_prob: dropout probability for hidden layers + :type hidden_dropout_prob: float + :param dtype: dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param checkpoint: whether to checkpoint the layer, defaults to False + :type checkpoint: bool, optional + """ + + def __init__(self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size) + self.attention_head_size = divide(hidden_size, num_attention_heads) + self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size) + + self.checkpoint = checkpoint + + self.query_key_value = Linear1D_Col( + hidden_size, + 3 * hidden_size, + dtype=dtype, + ) + self.attention_dropout = nn.Dropout(attention_dropout_prob) + self.dense = Linear1D_Row( + hidden_size, + hidden_size, + dtype=dtype, + parallel_input=True, + ) + self.dropout = nn.Dropout(hidden_dropout_prob) + self.softmax = nn.Softmax(dim=-1) + self.scale = self.num_attention_heads ** -0.5 + + def _forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.query_key_value(x).reshape(B, N, 3, self.num_attention_heads, C // + (self.num_attention_heads * gpc.tensor_parallel_size)).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + with seed(ParallelMode.TENSOR): + attn = self.attention_dropout(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C // gpc.tensor_parallel_size) + x = self.dense(x) + with seed(ParallelMode.TENSOR): + x = self.dropout(x) + return x + + def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor: + return checkpoint(self._forward, hidden_states) + + def forward(self, hidden_states: Tensor) -> Tensor: + if self.checkpoint: + return self._checkpoint_forward(hidden_states) + else: + return self._forward(hidden_states) + +@LAYERS.register_module +class ViTHead1D(ParallelLayer): + """Output layer for 1D parallel Vision Transformer + + :param hidden_size: hidden size + :type hidden_size: int + :param num_classes: number of classes + :type num_classes: int + :param dtype: dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + """ + + def __init__(self, + hidden_size, + num_classes, + dtype=None, + ): + super().__init__() + + self.linear = Linear1D_Col( + hidden_size, + num_classes, + 
dtype=dtype, + gather_output = True, + ) + + def forward(self, x: Tensor) -> Tensor: + x = x[:, 0] + x = self.linear(x) + return x + +@LAYERS.register_module +class ViTHeadNormal(ParallelLayer): + """Output layer for 1D parallel Vision Transformer + + :param hidden_size: hidden size + :type hidden_size: int + :param num_classes: number of classes + :type num_classes: int + :param dtype: dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + """ + + def __init__(self, + hidden_size, + num_classes, + dtype=None, + ): + super().__init__() + + self.linear = nn.Linear( + hidden_size, + num_classes, + dtype = dtype + ) + self._broadcast_linear_params() + + def _broadcast_linear_params(self) -> None: + self.to(get_current_device()) + ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D) + + dist.broadcast(self.linear.weight, src=ranks[0], + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + dist.broadcast(self.linear.bias, src=ranks[0], + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + def forward(self, x: Tensor) -> Tensor: + x = x[:, 0] + x = self.linear(x) + return x + +@LAYERS.register_module +class ViTPatchEmbedding1D(ParallelLayer): + """ 2D Image to Patch Embedding + + :param img_size: iamge size + :type img_size: int + :param patch_size: patch size + :type patch_size: int + :param embed_dim: dimension of embedding + :type embed_dim: int + :param in_chans: number of channels of input image, defaults to 3 + :type in_chans: int, optional + :param flatten: whether to flatten output tensor, defaults to True + :type flatten: bool, optional + """ + + def __init__(self, + img_size, + patch_size, + embed_dim, + in_chans=3, + flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.embed_dim = embed_dim + + with seed(ParallelMode.TENSOR): + self.proj = nn.Conv2d(in_chans, + self.embed_dim, + kernel_size=patch_size, + stride=patch_size + ) + + # sync + self._broadcast_conv_params() + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute(self.proj.weight) + set_tensor_parallel_attribute(self.proj.bias) + + def _broadcast_conv_params(self) -> None: + self.to(get_current_device()) + ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D) + + dist.broadcast(self.proj.weight, src=ranks[0], + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + dist.broadcast(self.proj.bias, src=ranks[0], + group=gpc.get_group(ParallelMode.PARALLEL_1D)) + + def forward(self, x: Tensor) -> Tensor: + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + return x + + +@LAYERS.register_module +class ViTTokenFuser1D(ParallelLayer): + """ + Fuse cls token and pos embedding to the input + + :param img_size: image size + :type img_size: int + :param patch_size: patch size + :type patch_size: int + :param embed_dim: dimension of embedding + :type embed_dim: int + :param drop_rate: dropout probability, defaults to 0. + :type drop_rate: float, optional + """ + + def __init__(self, + img_size, + patch_size, + embed_dim, + drop_rate=0. 
+ ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.embed_dim = embed_dim + + self.cls_token = nn.Parameter(torch.zeros( + 1, 1, self.embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros( + 1, self.num_patches + 1, self.embed_dim)) + + # move to cuda before broadcast + self.to(get_current_device()) + + # sync param in both forward and backward + _cls_token = self.cls_token.view(-1) + _pos_embed = self.pos_embed.view(-1) + self._param = torch.cat([_cls_token, _pos_embed], dim=0) + + self.pos_drop = nn.Dropout(p=drop_rate) + self._set_tensor_parallel_attribute() + + def _set_tensor_parallel_attribute(self): + set_tensor_parallel_attribute(self.cls_token) + set_tensor_parallel_attribute(self.pos_embed) + + def forward(self, x: Tensor) -> Tensor: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + with seed(ParallelMode.TENSOR): + x = self.pos_drop(x + self.pos_embed) + return x.contiguous() + diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 572eca7775ee..b67d0a2c1f42 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import math +import numbers import torch import torch.nn as nn import torch.nn.functional as F @@ -8,17 +10,22 @@ from torch import Tensor from torch.nn.parameter import Parameter from typing import Tuple +import importlib -from colossalai.context.parallel_mode import ParallelMode +from colossalai.context import seed, ParallelMode from colossalai.core import global_context as gpc from colossalai.registry import LAYERS from colossalai.utils import get_current_device -from .._common_utils import divide +from .._common_utils import divide, set_tensor_parallel_attribute from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \ split_forward_gather_backward from ..base_layer import ParallelLayer +global fused_mix_prec_layer_norm_cuda +fused_mix_prec_layer_norm_cuda = None + +@LAYERS.register_module class Linear1D_Col(ParallelLayer): """Linear layer with column parallelism. @@ -48,19 +55,18 @@ def __init__(self, super().__init__() # Keep input parameters - self.input_size = in_features - self.output_size = output_size + self.in_features = in_features + self.out_features = output_size self.gather_output = gather_output self.skip_bias_add = not bias - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - self.output_size_per_partition = divide(output_size, world_size) + self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size) # Parameters. # Initialize weight. 
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} self.weight = Parameter(torch.empty( - self.output_size_per_partition, self.input_size, + self.output_size_per_partition, self.in_features, **factory_kwargs)) if bias: @@ -72,6 +78,30 @@ def __init__(self, self.bias.zero_() else: self.register_parameter('bias', None) + self.reset_parameters() + self._set_tensor_parallel_attributes() + + def reset_parameters(self) -> None: + fan_in = self.in_features + a = math.sqrt(5) + nonlinearity = 'leaky_relu' + + # init weight + std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) + bound = math.sqrt(3.0) * std + with seed(ParallelMode.TENSOR): + nn.init.uniform_(self.weight, -bound, bound) + + # init bias + if not self.skip_bias_add: + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + with seed(ParallelMode.TENSOR): + nn.init.uniform_(self.bias, -bound, bound) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute(self.weight) + if self.bias is not None: + set_tensor_parallel_attribute(self.bias) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Set up backprop all-reduce. @@ -104,7 +134,7 @@ class Linear1D_Row(ParallelLayer): :type bias: bool, optional :param dtype: The dtype of parameters, defaults to None :type dtype: torch.dtype, optional - :param parallel_input: If set to ``False``, it's assumed that the input is splitted, defaults to False + :param parallel_input: If set to ``True``, it's assumed that the input is splitted, defaults to False :type parallel_input: bool, optional """ @@ -124,15 +154,14 @@ def __init__(self, self.skip_bias_add = not bias # Divide the weight matrix along the last dimension. - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - self.input_size_per_partition = divide(in_features, world_size) + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. factory_kwargs = {'device': get_current_device(), 'dtype': dtype} self.weight = Parameter(torch.empty( - self.out_features, - self.input_size_per_partition, + self.out_features, + self.input_size_per_partition, **factory_kwargs)) if bias: @@ -146,9 +175,31 @@ def __init__(self, self.bias.zero_() else: self.register_parameter('bias', None) - + self.reset_parameters() + self._set_tensor_parallel_attributes() + def reset_parameters(self) -> None: - init.xavier_normal_(self.weight) + # setting + fan_in = self.in_features + a = math.sqrt(5) + nonlinearity = 'leaky_relu' + + # init weight + std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in) + bound = math.sqrt(3.0) * std + with seed(ParallelMode.TENSOR): + nn.init.uniform_(self.weight, -bound, bound) + + # init bias + if not self.skip_bias_add: + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + with seed(ParallelMode.TENSOR): + nn.init.uniform_(self.bias, -bound, bound) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute(self.weight) + if self.bias is not None: + set_tensor_parallel_attribute(self.bias) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. 
@@ -164,3 +215,68 @@ def forward(self, input_: Tensor) -> Tensor: if not self.skip_bias_add: output = output + self.bias return output + + + + +@LAYERS.register_module +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + + return output + + + @staticmethod + def backward(ctx, grad_output): + + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +@LAYERS.register_module +class MixedFusedLayerNorm1D(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5): + super(MixedFusedLayerNorm1D, self).__init__() + + global fused_mix_prec_layer_norm_cuda + fused_mix_prec_layer_norm_cuda = importlib.import_module( + "fused_mix_prec_layer_norm_cuda") + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + + + def reset_parameters(self): + + init.ones_(self.weight) + init.zeros_(self.bias) + + + def forward(self, input): + + return FusedLayerNormAffineFunction1D.apply( + input, self.weight, self.bias, self.normalized_shape,self.eps) \ No newline at end of file diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index 6015c55c6dea..e72f84bad417 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -1,6 +1,7 @@ from .base_loss import BaseLoss +from .cross_entropy_1d import CrossEntropyLoss1D from .cross_entropy_2d import CrossEntropyLoss2D from .cross_entropy_2p5d import CrossEntropyLoss2p5D from .cross_entropy_3d import CrossEntropyLoss3D -__all__ = ['CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D'] +__all__ = ['CrossEntropyLoss1D', 'CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D'] diff --git a/colossalai/nn/loss/cross_entropy_1d.py b/colossalai/nn/loss/cross_entropy_1d.py index 667c007344c1..1b8349bbfaa3 100644 --- a/colossalai/nn/loss/cross_entropy_1d.py +++ b/colossalai/nn/loss/cross_entropy_1d.py @@ -1,13 +1,16 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from colossalai.nn.layer.parallel_1d.layers import Linear1D_Col +from colossalai.utils.cuda import get_current_device import torch import torch.nn.functional as F from torch.nn.modules.loss import _Loss +from colossalai.registry import LOSSES from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size +from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size, vocab_range_from_global_vocab_size class _VocabParallelCrossEntropy_1D(torch.autograd.Function): @@ -16,18 +19,24 @@ class 
_VocabParallelCrossEntropy_1D(torch.autograd.Function): def forward(ctx, vocab_parallel_logits, target): # Maximum value along vocab dimension across all GPUs. logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + # print("logits_max shape:", logits_max.size()) torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_1D)) + # print("logits_max shape after all reduce:", logits_max.size()) # Subtract the maximum value. - vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) + # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + # print("vocab_parallel_logits shape:", vocab_parallel_logits.size()) # Get the partition's vocab indecies - partition_vocab_size = vocab_parallel_logits.size()[-1] + # partition_vocab_size = vocab_parallel_logits.size()[-1] + partition_vocab_size = vocab_parallel_logits.size(-1) rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size( + vocab_start_index, vocab_end_index = vocab_range_from_global_vocab_size( partition_vocab_size, rank, world_size) + # print("partition_vocab_size, rank, world_size, vocab_start_index, vocab_end_index: ", partition_vocab_size, rank, world_size, vocab_start_index, vocab_end_index) # Create a mask of valid vocab ids (1 means it needs to be masked). target_mask = (target < vocab_start_index) | (target >= vocab_end_index) @@ -37,13 +46,15 @@ def forward(ctx, vocab_parallel_logits, target): # Get predicted-logits = logits[target]. # For Simplicity, we convert logits to a 2-D tensor with size # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], - device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) + # logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size).contiguous() + # masked_target_1d = masked_target.view(-1).contiguous() + # arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + # device=logits_2d.device) + arange_1d = torch.arange(start=0, end=vocab_parallel_logits.size()[0]) + # predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits = vocab_parallel_logits[arange_1d, masked_target] + # predicted_logits_1d = predicted_logits_1d.clone().contiguous() + # predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 # All reduce is needed to get the chunks from other GPUs. torch.distributed.all_reduce(predicted_logits, @@ -51,8 +62,9 @@ def forward(ctx, vocab_parallel_logits, target): group=gpc.get_group(ParallelMode.PARALLEL_1D)) # Sum of exponential of logits along vocab dimension across all GPUs. 
- exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) + # exp_logits = vocab_parallel_logits + # torch.exp(vocab_parallel_logits, out=exp_logits) + exp_logits = torch.exp(vocab_parallel_logits) sum_exp_logits = exp_logits.sum(dim=-1) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, @@ -63,7 +75,7 @@ def forward(ctx, vocab_parallel_logits, target): # Store softmax, target-mask and masked-target for backward pass. exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + ctx.save_for_backward(exp_logits, target_mask, masked_target) return loss @@ -80,7 +92,7 @@ def backward(ctx, grad_output): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], - device=grad_2d.device) + device=get_current_device()) grad_2d[arange_1d, masked_target_1d] -= ( 1.0 - target_mask.view(-1).float()) @@ -89,7 +101,7 @@ def backward(ctx, grad_output): return grad_input, None - +@LOSSES.register_module class LmLoss1D(_Loss): def forward(self, lm_logits, lm_labels, loss_mask): @@ -98,7 +110,7 @@ def forward(self, lm_logits, lm_labels, loss_mask): lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() return lm_loss - +@LOSSES.register_module class SopLoss1D(_Loss): def forward(self, sop_logits, sentence_order): @@ -107,7 +119,7 @@ def forward(self, sop_logits, sentence_order): ignore_index=-1) return sop_loss - +@LOSSES.register_module class BERTDualHeadLoss(_Loss): def __init__(self): @@ -118,3 +130,30 @@ def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order): lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask) sop_loss = self.sop_loss(sop_logits, sentence_order) return lm_loss + sop_loss + +@LOSSES.register_module +class CrossEntropyLoss1D(_Loss): + """Cross entropy loss for 1D parallelism + + :param reduction: whether to average the loss, defaults to True + :type reduction: bool, optional + """ + + def __init__(self): + super().__init__() + self.dim = gpc.tensor_parallel_size + + def forward(self, logits, targets): + # loss = _VocabParallelCrossEntropy_1D.apply( + # logits, targets, + # ) + # print("loss :", loss.size()) + # print("loss contiguous or not: ", loss.is_contiguous()) + # return loss + + # dist_loss = loss.mean() + + # test tp=1 + loss = torch.nn.CrossEntropyLoss() + dist_loss = loss(logits, targets) + return dist_loss diff --git a/colossalai/nn/loss/cross_entropy_2d.py b/colossalai/nn/loss/cross_entropy_2d.py index 57ac4e985804..3bb5712aa177 100644 --- a/colossalai/nn/loss/cross_entropy_2d.py +++ b/colossalai/nn/loss/cross_entropy_2d.py @@ -18,9 +18,7 @@ class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function): def forward(ctx, logits, targets): # logits: [b/q, h/q] # labels: [b/q] - # loss: [b/q] - # vocab_parallel_logits: [b/q, s, v/q] - # target: [b/q, s] + logits_max = torch.max(logits, dim=-1)[0] torch.distributed.all_reduce( logits_max, diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 8c3478c71336..cf345d0f8117 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -6,7 +6,7 @@ from colossalai.utils import is_no_pp_or_last_stage from ._base_hook import BaseHook from .._trainer import Trainer -from ..metric import Loss, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D +from ..metric import Loss, Accuracy1D, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D class MetricHook(BaseHook): @@ 
-74,6 +74,36 @@ def after_test_iter(self, logits, label, loss): self.test_loss.update(loss) +@HOOKS.register_module +class Accuracy1DHook(MetricHook): + """Specialized hook class for :class:`Accuracy1D`. + It acts the same as :class:`AccuracyHook`. + + :param trainer: Trainer attached with current hook + :param priority: Priority in the printing, hooks with small priority will be printed in front + :type trainer: Trainer + :type priority: int, optional + """ + + def __init__(self, trainer: Trainer, priority: int = 10): + super().__init__(trainer, priority) + + if self._is_stage_to_compute: + self.metric = Accuracy1D(epoch_only=True) + + # register the metric + self.trainer.states['metrics']['test'][ + self.metric.__class__.__name__] = self.metric + + def before_test(self): + if self._is_stage_to_compute: + self.metric.reset() + + def after_test_iter(self, logits, label, *args): + if self._is_stage_to_compute: + self.metric.update(logits, label) + + @HOOKS.register_module class Accuracy2DHook(MetricHook): """Specialized hook class for :class:`Accuracy2D`. diff --git a/colossalai/trainer/metric.py b/colossalai/trainer/metric.py index b595d37b823c..d0255b4ea3b3 100644 --- a/colossalai/trainer/metric.py +++ b/colossalai/trainer/metric.py @@ -211,7 +211,6 @@ def get_accumulated_value(self): def is_better(a, b) -> bool: return a > b - class Accuracy2D(Accuracy): """A metric collector for accuracy. It only works for classification tasks. This class is the same as :class:`Accuracy` but used in 2D @@ -248,6 +247,38 @@ def update(self, logits, label) -> None: self.accumulated_sum += self.last_step_sum self.accumulated_correct += self.last_step_correct +class Accuracy1D(Accuracy): + """A metric collector for accuracy. It only works for classification + tasks. This class is the same as :class:`Accuracy` but used in 2D + model parallelism. 
+ + :param epoch_only: Whether the metric only read for the full epoch + :type epoch_only: bool + """ + + def __init__(self, epoch_only: bool): + super().__init__(epoch_only=epoch_only) + + def update(self, logits, label) -> None: + if isinstance(logits, (list, tuple)): + logits = logits[0] + if isinstance(label, (list, tuple)): + label = label[0] + + logits = _gather( + logits, + ParallelMode.PARALLEL_1D, + 1 + ) + + # update + preds = torch.argmax(logits, dim=-1) + correct = torch.sum(label == preds) + self.last_step_sum.fill_(label.size(0)) + self.last_step_correct.fill_(correct) + self.accumulated_sum += self.last_step_sum + self.accumulated_correct += self.last_step_correct + class Accuracy2p5D(Accuracy): def __init__(self, epoch_only: bool): diff --git a/model_zoo/vit/parallel_1d/vit.py b/model_zoo/vit/parallel_1d/vit.py new file mode 100644 index 000000000000..e471fed143bf --- /dev/null +++ b/model_zoo/vit/parallel_1d/vit.py @@ -0,0 +1,208 @@ +import torch +from torch import nn + +from colossalai import nn as col_nn +from colossalai.context import ParallelMode +from colossalai.registry import MODELS + +__all__ = [ + 'VisionTransformer3D', + 'vit_tiny_1d_patch4_32', + 'vit_tiny_1d_patch16_224', + 'vit_tiny_1d_patch16_384', + 'vit_small_1d_patch16_224', + 'vit_small_1d_patch16_384', + 'vit_small_1d_patch32_224', + 'vit_small_1d_patch32_384', + 'vit_base_1d_patch16_224', + 'vit_base_1d_patch16_384', + 'vit_base_1d_patch32_224', + 'vit_base_1d_patch32_384', + 'vit_large_1d_patch16_224', + 'vit_large_1d_patch16_384', + 'vit_large_1d_patch32_224', + 'vit_large_1d_patch32_384', +] + + +class ViTBlock1D(nn.Module): + def __init__(self, + dim: int, + num_heads: int, + hidden_dim: int, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0.): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.attn = col_nn.ViTSelfAttention1D(dim, num_heads, attn_drop, drop) + self.drop_path = col_nn.VanillaViTDropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + self.mlp = col_nn.ViTMLP1D(dim, 1, drop, 'gelu') + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +@MODELS.register_module +class VisionTransformer1D(nn.Module): + def __init__(self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + depth: int = 12, + num_heads: int = 12, + embed_dim: int = 768, + hidden_dim: int = 3072, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = col_nn.ViTPatchEmbedding1D( + img_size, + patch_size, + in_chans, + embed_dim, + drop_rate, + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = nn.Sequential(*[ + ViTBlock1D(embed_dim, num_heads, hidden_dim, + drop_rate, attn_drop_rate, dpr[i]) + for i in range(depth) + ]) + + self.norm = nn.LayerNorm(embed_dim, ParallelMode.PARALLEL_3D_INPUT, + ParallelMode.PARALLEL_3D_WEIGHT) + + self.head = col_nn.ViTHead1D(hidden_dim, num_classes) + self.init_weights() + + def init_weights(self): + pass + + def forward(self, x): + x = self.patch_embed(x) + x = self.blocks(x) + x = self.norm(x) + x = self.head(x) + return x + + +def _create_vit_model(**model_kwargs): + model = VisionTransformer1D(**model_kwargs) + return model + + +@MODELS.register_module +def vit_tiny_1d_patch4_32(**kwargs): + model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512, + depth=6, num_heads=8, hidden_dim=512, num_classes=10, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_tiny_1d_patch16_224(**kwargs): + model_kwargs = dict(patch_size=16, embed_dim=192, + depth=12, num_heads=3, hidden_dim=768, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_tiny_1d_patch16_384(**kwargs): + model_kwargs = dict(img_size=384, patch_size=16, + embed_dim=192, depth=12, num_heads=3, hidden_dim=768, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_small_1d_patch16_224(**kwargs): + model_kwargs = dict(patch_size=16, embed_dim=384, + depth=12, num_heads=6, hidden_dim=1536, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_small_1d_patch16_384(**kwargs): + model_kwargs = dict(img_size=384, patch_size=16, + embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_small_1d_patch32_224(**kwargs): + model_kwargs = dict(patch_size=32, embed_dim=384, + depth=12, num_heads=6, hidden_dim=1536, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_small_1d_patch32_384(**kwargs): + model_kwargs = dict(img_size=384, patch_size=32, + embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_base_1d_patch16_224(**kwargs): + model_kwargs = dict(patch_size=16, embed_dim=768, + depth=12, num_heads=12, hidden_dim=3072, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_base_1d_patch16_384(**kwargs): + model_kwargs = dict(img_size=384, patch_size=16, + embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs) + return 
_create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_base_3d_patch32_224(**kwargs): + model_kwargs = dict(patch_size=32, embed_dim=768, + depth=12, num_heads=12, hidden_dim=3072, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_base_1d_patch32_384(**kwargs): + model_kwargs = dict(img_size=384, patch_size=32, + embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_large_3d_patch16_224(**kwargs): + model_kwargs = dict(patch_size=16, embed_dim=1024, + depth=24, num_heads=16, hidden_dim=4096, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_large_1d_patch16_384(**kwargs): + model_kwargs = dict(img_size=384, patch_size=16, + embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_large_1d_patch32_224(**kwargs): + model_kwargs = dict(patch_size=32, embed_dim=1024, + depth=24, num_heads=16, hidden_dim=4096, **kwargs) + return _create_vit_model(**model_kwargs) + + +@MODELS.register_module +def vit_large_1d_patch32_384(**kwargs): + model_kwargs = dict(img_size=384, patch_size=32, + embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs) + return _create_vit_model(**model_kwargs) diff --git a/tests/test_layers/test_1d/common.py b/tests/test_layers/test_1d/common.py index 64d4601cb6aa..a17cae9d316d 100644 --- a/tests/test_layers/test_1d/common.py +++ b/tests/test_layers/test_1d/common.py @@ -6,8 +6,9 @@ DEPTH = 2 BATCH_SIZE = 8 SEQ_LENGTH = 8 +IMG_SIZE = 16 HIDDEN_SIZE = 8 - +NUM_CLASSES = 10 def check_equal(A, B): assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py index e89cfe9725af..2d6436b05a62 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_layers/test_1d/test_1d.py @@ -5,7 +5,7 @@ from colossalai.core import global_context as gpc from colossalai.initialize import init_dist -from test_layer import check_linear_col, check_linear_row +from test_layer import * CONFIG = dict( parallel=dict( @@ -19,20 +19,24 @@ def check_layer(): + # print_rank_0('start check_linear_col') check_linear_col() check_linear_row() - # check_attention() - # check_mlp() - + check_attention() + check_mlp() + check_patch_embedding() + check_embed() + check_head() @pytest.mark.dist @pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") -def test_2d(): +def test_1d(): init_dist(config=CONFIG) gpc.set_seed() + check_layer() gpc.destroy() if __name__ == '__main__': - test_2d() + test_1d() diff --git a/tests/test_layers/test_1d/test_layer.py b/tests/test_layers/test_1d/test_layer.py index 59551a5cac32..3fa9eee9dc94 100644 --- a/tests/test_layers/test_1d/test_layer.py +++ b/tests/test_layers/test_1d/test_layer.py @@ -1,14 +1,14 @@ +from colossalai.nn.optimizer.zero_redundancy_optimizer_level_2 import print_rank_msg +from tests.test_layers.test_3d.common import IMG_SIZE import torch import torch.distributed as dist from torch.nn import Parameter - +import time from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn import Linear1D_Col, Linear1D_Row -# TransformerMLP1D, \ -# TransformerSelfAttention1D, TransformerEncoderLayer1D +from colossalai.nn import Linear1D_Col, Linear1D_Row, TransformerMLP1D, 
TransformerSelfAttention1D, ViTMLP1D, ViTSelfAttention1D, ViTPatchEmbedding1D, ViTHead1D, ViTTokenFuser1D from colossalai.utils import get_current_device, print_rank_0 -from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal +from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE def check_linear_col(): @@ -142,70 +142,271 @@ def check_linear_row(): print_rank_0('linear_row no parallel_input backward: pass') -# -# def check_attention(): -# device = get_current_device() -# dtype = torch.float32 -# INPUT_SIZE = HIDDEN_SIZE -# NUM_ATTENTION_HEADS = 2 -# -# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) -# -# layer = TransformerSelfAttention1D( -# 1, -# HIDDEN_SIZE // NUM_ATTENTION_HEADS, -# HIDDEN_SIZE, -# NUM_ATTENTION_HEADS, -# 0.5 -# ) -# -# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) -# A_master = torch.randn(A_shape, dtype=dtype, device=device) -# torch.distributed.broadcast(A_master, src=0) -# A = A_master.clone() -# A.requires_grad = True -# -# mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH) -# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) -# -# out = layer(A, attention_mask) -# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) -# print_rank_0('self attention forward: pass') -# -# grad_shape = out.shape -# grad = torch.randn(grad_shape, dtype=dtype, device=device) -# -# out.backward(grad) -# assert A.grad.shape == A.shape -# print_rank_0('self attention backward: pass') -# -# -# def check_mlp(): -# device = get_current_device() -# dtype = torch.float32 -# INPUT_SIZE = HIDDEN_SIZE -# -# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) -# -# layer = TransformerMLP1D( -# HIDDEN_SIZE, -# HIDDEN_SIZE, -# 4.0 -# ) -# -# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) -# A_master = torch.randn(A_shape, dtype=dtype, device=device) -# torch.distributed.broadcast(A_master, src=0) -# A = A_master.clone() -# A.requires_grad = True -# -# out = layer(A) -# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) -# print_rank_0('mlp forward: pass') -# -# grad_shape = out.shape -# grad = torch.randn(grad_shape, dtype=dtype, device=device) -# -# out.backward(grad) -# assert A.grad.shape == A.shape -# print_rank_0('mlp backward: pass') +class Testvithead(torch.nn.Module): + def __init__(self, in_features, out_features, bias=True): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x): + x = x[:, 0] + x = self.linear(x) + return x + +def check_head(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + head = ViTHead1D(INPUT_SIZE,NUM_CLASSES,dtype=dtype) + torch.nn.init.zeros_(head.linear.bias) + torch.nn.init.ones_(head.linear.weight) + head = head.to(device) + + layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True) + torch.nn.init.zeros_(layer.linear.bias) + torch.nn.init.ones_(layer.linear.weight) + layer = layer.to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + fwd_start = time.time() + out = head(A) + fwd_end = time.time() + print_rank_0( + 'head forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start)) + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer(A_master) + # C = 
torch.chunk(C_master, DEPTH, dim=0)[i] + print_rank_msg('Rank {} head forward: {}'.format(i, check_equal(out, C_master))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, + dtype=dtype, + device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + # grad = torch.chunk(grad_master, DEPTH, dim=0)[i] + + # bwd_start = time.time() + out.backward(grad_master) + # bwd_end = time.time() + # print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start), + # logger) + + C_master.backward(grad_master) + A_grad = A_master.grad + # if j == 0: + print_rank_0('Rank {} head backward (input_grad): {}'.format( + i, check_equal(A_grad, A.grad))) + + + +class Testvitembed(torch.nn.Module): + def __init__(self, img_size: int, patch_size: int, in_chans: int, + embed_size: int, drop_prob: float) -> None: + super().__init__() + self.proj = torch.nn.Conv2d(in_chans, + embed_size, + kernel_size=patch_size, + stride=patch_size) + num_patches = (img_size // patch_size)**2 + self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_size)) + self.pos_embed = torch.nn.Parameter( + torch.zeros(1, num_patches + 1, embed_size)) + self.pos_drop = torch.nn.Dropout(drop_prob) + + def forward(self, x): + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + return x + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = ViTPatchEmbedding1D(IMG_SIZE, 4, HIDDEN_SIZE) + layer2 = ViTTokenFuser1D(IMG_SIZE, 4, HIDDEN_SIZE) + torch.nn.init.zeros_(layer.proj.bias) + torch.nn.init.ones_(layer.proj.weight) + torch.nn.init.ones_(layer2.cls_token) + torch.nn.init.ones_(layer2.pos_embed) + layer = layer.to(device) + layer2 = layer2.to(device) + + layer_master = Testvitembed(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.) 
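+ # layer_master above is the unsharded reference: ViTPatchEmbedding1D broadcasts its conv parameters and
+ # ViTTokenFuser1D keeps the full cls_token/pos_embed, so the parallel output is expected to match this
+ # reference without any chunking of the result.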
+ torch.nn.init.zeros_(layer_master.proj.bias) + torch.nn.init.ones_(layer_master.proj.weight) + torch.nn.init.ones_(layer_master.cls_token) + torch.nn.init.ones_(layer_master.pos_embed) + layer_master = layer_master.to(device) + + A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + fwd_start = time.time() + out = layer2(layer(A)) + fwd_end = time.time() + print_rank_0( + 'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start)) + # out_cls = out[:, 0] + # out_tensor = out[:, 1:] + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + # if j == 0: + # C_cls = C_master[:, 0] + # C_cls = torch.chunk(C_cls, DEPTH, dim=0)[i] + # C_cls = torch.chunk(C_cls, DEPTH, dim=-1)[k] + # logger.info('Rank {} embed forward (cls): {}'.format( + # rank, check_equal(out_cls, C_cls))) + # C = C_master[:, 1:] + print_rank_msg('Rank {} embed forward: {}'.format(i, check_equal(out, C_master))) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, + dtype=dtype, + device=get_current_device()) + torch.distributed.broadcast(grad_master, src=0) + # cls_grad = grad_master[:, 0] + # cls_grad = torch.chunk(cls_grad, DEPTH, dim=0)[i] + # cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[k] + # grad = grad_master[:, 1:] + # grad = torch.cat((torch.unsqueeze(cls_grad, 1), grad), dim=1) + bwd_start = time.time() + out.backward(grad_master) + bwd_end = time.time() + print_rank_0( + 'embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start)) + + C_master.backward(grad_master) + + A_grad = A_master.grad + print_rank_msg('Rank {} embed backward (input_grad): {}'.format(i, check_equal(A_grad, A.grad))) + + print_rank_0('Rank {} embed backward (cls_grad): {}'.format( + i, check_equal(layer_master.cls_token.grad, layer2.cls_token.grad))) + + print_rank_0('Rank {} embed backward (pos_embed_grad): {}'.format( + i, check_equal(layer_master.pos_embed.grad, layer2.pos_embed.grad))) + + print_rank_msg('Rank {} embed backward (proj_weight_grad): {}'.format( + i, check_equal(layer_master.proj.weight.grad, layer.proj.weight.grad))) + + print_rank_msg('Rank {} embed backward (proj_bias_grad): {}'.format( + i, check_equal(layer_master.proj.bias.grad, layer.proj.bias.grad))) + + return fwd_end - fwd_start, bwd_end - bwd_start + +def check_attention(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + NUM_ATTENTION_HEADS = 2 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + + layer = ViTSelfAttention1D( + HIDDEN_SIZE, + NUM_ATTENTION_HEADS, + 0.5, + 0.5 + ).to(device=device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH) + attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device) + + out = layer(A) + assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + print_rank_0('self attention forward: pass') + + grad_shape = out.shape + grad = torch.randn(grad_shape, dtype=dtype, device=device) + + out.backward(grad) + assert A.grad.shape == A.shape + print_rank_0('self attention backward: pass') + + +def check_mlp(): + device = get_current_device() + dtype = torch.float32 + 
INPUT_SIZE = HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = ViTMLP1D( + HIDDEN_SIZE, + 4.0 + ).to(device=device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + print_rank_0('mlp forward: pass') + + grad_shape = out.shape + grad = torch.randn(grad_shape, dtype=dtype, device=device) + + out.backward(grad) + assert A.grad.shape == A.shape + print_rank_0('mlp backward: pass') + +def check_patch_embedding(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = 4 + PATCH_SIZE = 2 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = ViTPatchEmbedding1D( + INPUT_SIZE, + PATCH_SIZE, + HIDDEN_SIZE, + ).to(device=device) + + A_shape = (BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + print('output size: ',out.size()) + assert out.shape == (BATCH_SIZE, 4, HIDDEN_SIZE) + print_rank_0('patch embedding forward: pass') + + grad_shape = out.shape + grad = torch.randn(grad_shape, dtype=dtype, device=device) + + out.backward(grad) + assert A.grad.shape == A.shape + print_rank_0('patch embedding backward: pass') diff --git a/tests/test_models/test_vision_transformer/configs/vit_1d.py b/tests/test_models/test_vision_transformer/configs/vit_1d.py new file mode 100644 index 000000000000..40e28f6f186f --- /dev/null +++ b/tests/test_models/test_vision_transformer/configs/vit_1d.py @@ -0,0 +1,137 @@ +import os +from pathlib import Path + +BATCH_SIZE = 512 +IMG_SIZE = 32 +PATCH_SIZE = 4 +DIM = 512 +NUM_ATTENTION_HEADS = 8 +NUM_CLASSES = 10 +DEPTH = 6 +LOG_NAME = 'vit1D_cifar10_tp=2_selfattention_V2' + +# # ViT Base +# BATCH_SIZE = 512 +# IMG_SIZE = 224 +# PATCH_SIZE = 16 +# DIM = 384 +# NUM_ATTENTION_HEADS = 6 +# NUM_CLASSES = 100 +# DEPTH = 12 +# LOG_NAME = 'vit1D_imagenet100' + +train_data = dict( + dataset=dict( + type='CIFAR10Dataset', + root=Path(os.environ['DATA']), + download = True, + transform_pipeline=[ + dict(type='RandomCrop', size=IMG_SIZE, padding=4), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', + mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010]), + ]), + dataloader=dict(batch_size=BATCH_SIZE, + pin_memory=True, + num_workers=4, + shuffle=True)) + +test_data = dict( + dataset=dict( + type='CIFAR10Dataset', + root=Path(os.environ['DATA']), + train=False, + transform_pipeline=[ + dict(type='Resize', size=IMG_SIZE), + dict(type='ToTensor'), + dict(type='Normalize', + mean=[0.4914, 0.4822, 0.4465], + std=[0.2023, 0.1994, 0.2010]), + ]), + dataloader=dict(batch_size=400, + pin_memory=True, + num_workers=4, + shuffle=True)) + +optimizer = dict(type='Adam', lr=0.001, weight_decay=0) + +loss = dict(type='CrossEntropyLoss1D', ) + +model = dict( + type='VisionTransformerFromConfig', + embedding_cfg=dict( + type='ViTPatchEmbedding1D', + img_size=IMG_SIZE, + patch_size=PATCH_SIZE, + embed_dim=DIM, + ), + token_fusion_cfg=dict(type='ViTTokenFuser1D', + img_size=IMG_SIZE, + patch_size=PATCH_SIZE, + embed_dim=DIM, + drop_rate=0.1), + norm_cfg=dict( + type='LayerNorm', + normalized_shape=DIM, + eps=1e-6, + ), + block_cfg=dict( + type='ViTBlock', + attention_cfg=dict( + type='ViTSelfAttention1DV2', + 
hidden_size=DIM, + num_attention_heads=NUM_ATTENTION_HEADS, + attention_dropout_prob=0., + hidden_dropout_prob=0.1, + ), + droppath_cfg=dict(type='VanillaViTDropPath', ), + mlp_cfg=dict(type='ViTMLP1D', + in_features=DIM, + dropout_prob=0.1, + mlp_ratio=1), + norm_cfg=dict( + type='LayerNorm', + normalized_shape=DIM, + eps=1e-6, + ), + ), + head_cfg=dict( + type='ViTHead1D', + hidden_size=DIM, + num_classes=NUM_CLASSES, + ), + embed_dim=DIM, + depth=DEPTH, + drop_path_rate=0., +) + +parallel = dict( + pipeline=dict(size=1), + tensor=dict(size=2, mode='1d'), +) + +hooks = [ + dict(type='LogMetricByEpochHook'), + # dict(type='LogTimingByEpochHook'), + # dict(type='LogMemoryByEpochHook'), + dict(type='TensorboardHook', log_dir=f'./tests/test_models/test_vision_transformer/test_vit_1d/tb_logs_{LOG_NAME}'), + dict( + type='Accuracy1DHook', + ), + dict(type='LossHook'), + # dict(type='TensorboardHook', log_dir='./tfb_logs'), + # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'), + # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt') +] + +logging = dict( + root_path=f"./tests/test_models/test_vision_transformer/test_vit_1d/{LOG_NAME}" +) + +lr_scheduler = dict(type='LinearWarmupLR', warmup_epochs=5) + +num_epochs = 70 + +seed = 42 \ No newline at end of file diff --git a/tests/test_models/test_vision_transformer/test_vit_1d/test_vit_1d.py b/tests/test_models/test_vision_transformer/test_vit_1d/test_vit_1d.py new file mode 100644 index 000000000000..6e7cef3af264 --- /dev/null +++ b/tests/test_models/test_vision_transformer/test_vit_1d/test_vit_1d.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from pathlib import Path + +import pytest +import torch.autograd + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine import Engine +from colossalai.logging import get_global_dist_logger +from colossalai.nn.layer._parallel_utilities import _gather +from colossalai.trainer import Trainer + +CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_1d.py') + + +def eval(engine): + engine.eval() + accumulated_loss = 0 + correct_sum = 0 + total_sum = 0 + + for i in range(engine.schedule.num_steps): + output, label, loss = engine.step() + accumulated_loss += loss.detach().cpu().numpy() + if isinstance(output, (list, tuple)): + output = output[0] + if isinstance(label, (list, tuple)): + label = label[0] + output = torch.argmax(output, dim=-1) + correct = torch.sum(label == output) + correct_sum += correct + total_sum += label.size(0) + avg_loss = accumulated_loss / engine.schedule.num_steps + return correct_sum, total_sum, avg_loss + + +def train(engine): + engine.train() + accumulated_loss = 0 + + for i in range(engine.schedule.num_steps): + output, label, loss = engine.step() + accumulated_loss += loss.detach().cpu().numpy() + avg_loss = accumulated_loss / engine.schedule.num_steps + return avg_loss + + +@pytest.mark.dist +@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus") +def test_1d_parallel_vision_transformer(): + # init dist + model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize( + CONFIG_PATH) + logger = get_global_dist_logger() + + engine = Engine(model=model, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + criterion=criterion, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + schedule=schedule) 
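+    # colossalai.initialize builds the model, dataloaders, criterion, optimizer and
+    # schedule from the 1D-parallel config at CONFIG_PATH; the Engine is then
+    # stepped manually in the epoch loop below.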
+
+    logger.info('start training')
+    for epoch in range(gpc.config.num_epochs):
+        train_loss = train(engine)
+        logger.info(f'epoch {epoch} - train loss: {train_loss}')
+
+        if epoch % 2 == 0:
+            correct_sum, total_sum, eval_loss = eval(engine)
+            logger.info(
+                f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
+                f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
+
+def run_trainer():
+    # Entry point invoked via test.sh; kept distinct from the per-epoch
+    # train(engine) helper defined above so the two do not shadow each other.
+    model, train_dataloader, test_dataloader, criterion, \
+        optimizer, schedule, lr_scheduler = colossalai.initialize(CONFIG_PATH)
+
+    logger = get_global_dist_logger()
+    engine = Engine(model=model,
+                    train_dataloader=train_dataloader,
+                    test_dataloader=test_dataloader,
+                    criterion=criterion,
+                    optimizer=optimizer,
+                    lr_scheduler=lr_scheduler,
+                    schedule=schedule)
+    logger.info("Engine is built", ranks=[0])
+
+    trainer = Trainer(engine=engine, hooks_cfg=gpc.config.hooks, verbose=True)
+    logger.info("Trainer is built", ranks=[0])
+
+    logger.info("Train start", ranks=[0])
+    trainer.fit(train_dataloader=train_dataloader,
+                test_dataloader=test_dataloader,
+                max_epochs=gpc.config.num_epochs,
+                display_progress=True,
+                test_interval=1)
+
+if __name__ == '__main__':
+    run_trainer()
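A minimal, single-process sketch of the Megatron-style column-parallel -> row-parallel linear split that 1D tensor parallelism relies on, and that the layer tests above exercise across ranks. Here torch.chunk stands in for the per-rank weight shards and a plain Python sum stands in for the all-reduce, so everything below (WORLD_SIZE, w1, w2, ...) is illustrative rather than part of the Colossal-AI API.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
WORLD_SIZE = 2                     # pretend tensor-parallel size
BATCH, D_IN, D_MID = 4, 8, 16

x = torch.randn(BATCH, D_IN)
w1 = torch.randn(D_MID, D_IN)      # column-parallel weight: sharded along the output dim
w2 = torch.randn(D_IN, D_MID)      # row-parallel weight: sharded along the input dim

# Unpartitioned reference: x -> w1 -> gelu -> w2
ref = F.linear(F.gelu(F.linear(x, w1)), w2)

# Simulated tensor parallelism: each "rank" keeps one shard of w1 and w2.
partials = []
for rank in range(WORLD_SIZE):
    w1_shard = torch.chunk(w1, WORLD_SIZE, dim=0)[rank]
    w2_shard = torch.chunk(w2, WORLD_SIZE, dim=1)[rank]
    h = F.gelu(F.linear(x, w1_shard))        # hidden stays split; gelu is elementwise
    partials.append(F.linear(h, w2_shard))   # partial sum of the row-parallel matmul
out = torch.stack(partials).sum(dim=0)       # stands in for the all-reduce across ranks

assert torch.allclose(out, ref, atol=1e-5)
print('sharded result matches the unpartitioned reference')

Splitting column-first and row-second keeps the elementwise activation local to each shard, so the only communication needed per block is the single reduction applied to the second matmul's output.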