From fb7ea3732801cfadae84ee8c3ac465173060f305 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 19 Jun 2023 19:00:37 +0800 Subject: [PATCH 01/10] first v of vit shardformer --- colossalai/shardformer/policies/vit.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 colossalai/shardformer/policies/vit.py diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py new file mode 100644 index 000000000000..8769174e3704 --- /dev/null +++ b/colossalai/shardformer/policies/vit.py @@ -0,0 +1,10 @@ +rom typing import Dict, Union + +import torch.nn as nn + +from transformers.models.vit.modeling_vit import ViTModel + +from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + From 162dd1ec21091cf4cda4fcd7fb22fe63ae29ca0d Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 27 Jun 2023 15:52:41 +0800 Subject: [PATCH 02/10] keep vit --- colossalai/shardformer/policies/vit.py | 50 ++++++++++++++++- .../test_model/test_shard_vit.py | 55 +++++++++++++++++++ 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 tests/test_shardformer/test_model/test_shard_vit.py diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 8769174e3704..96527ce1350d 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -2,9 +2,57 @@ import torch.nn as nn -from transformers.models.vit.modeling_vit import ViTModel +from transformers.models.vit.modeling_vit import ViTModel, ViTLayer from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +class ViTPolicy(Policy): + + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + return { + ViTLayer:ModulePolicyDescription( + + ), + ViTModel: + ModulePolicyDescription( + attribute_replacement{ + + } + ), + + } + + @staticmethod + def embedding() -> List: + return[ + Embedding_Layer( + suffix="", + weight="weight", + replace_layer=col_nn.Embedding1D, + ) + ] + + @staticmethod + def dropout(): + return [Dropout_Layer( + suffix="dropout", + p="p", + replace_layer=col_nn.Dropout1D, + )] + + + + diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py new file mode 100644 index 000000000000..d5d71d9e29fe --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -0,0 +1,55 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, 
data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output) + + # do backward + org_loss.backward() + shard_loss.backward() + + # check grad + org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad + shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit(): + spawn(check_vit, 4) + + +if __name__ == "__main__": + test_vit() From 4047e4adf04e086e18e9dbd68e17579bcccf4c1e Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 20 Jun 2023 16:01:03 +0800 Subject: [PATCH 03/10] update --- colossalai/shardformer/layer/embedding1d.py | 149 +++++++ colossalai/shardformer/layer/layernorm1d.py | 73 ++++ colossalai/shardformer/layer/linear1d.py | 346 ++++++++++++++++ colossalai/shardformer/layer/linearconv1d.py | 377 ++++++++++++++++++ .../shardformer/layer/parallelmodule.py | 35 ++ .../layer/vocabparallelembedding1d.py | 170 ++++++++ colossalai/shardformer/policies/bert.py | 9 + .../test_model/test_shard_bert.py | 19 + .../test_model/test_shard_t5.py | 5 +- 9 files changed, 1182 insertions(+), 1 deletion(-) create mode 100644 colossalai/shardformer/layer/embedding1d.py create mode 100644 colossalai/shardformer/layer/layernorm1d.py create mode 100644 colossalai/shardformer/layer/linear1d.py create mode 100644 colossalai/shardformer/layer/linearconv1d.py create mode 100644 colossalai/shardformer/layer/parallelmodule.py create mode 100644 colossalai/shardformer/layer/vocabparallelembedding1d.py diff --git a/colossalai/shardformer/layer/embedding1d.py b/colossalai/shardformer/layer/embedding1d.py new file mode 100644 index 000000000000..1108d5d6a936 --- /dev/null +++ b/colossalai/shardformer/layer/embedding1d.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Callable, List, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise +from colossalai.utils.cuda import get_current_device + +from ._operation import gather_forward_split_backward +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: 
+ from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class Embedding1D(ParallelModule): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.process_group = process_group + self.num_partitions = dist.get_world_size(process_group) + self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + # self.gather_output = gather_output + + if device is None: + device = get_current_device() + + self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. 
+ """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding + + def reset_parameters(self, weight_initializer) -> None: + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + + return output diff --git a/colossalai/shardformer/layer/layernorm1d.py b/colossalai/shardformer/layer/layernorm1d.py new file mode 100644 index 000000000000..78bd64cfb504 --- /dev/null +++ b/colossalai/shardformer/layer/layernorm1d.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from collections import OrderedDict + +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.kernel import LayerNorm +from colossalai.nn import init as init +from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule +from colossalai.utils.checkpointing import broadcast_state_dict + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. 
+ """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/layer/linear1d.py b/colossalai/shardformer/layer/linear1d.py new file mode 100644 index 000000000000..d59d32df824e --- /dev/null +++ b/colossalai/shardformer/layer/linear1d.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device + +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. 
+ gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
+ process_group = process_group[0] + + linear_1d = Linear1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + sharded_weight = shard_rowwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + if bias: + sharded_bias = shard_colwise(module.bias.data, process_group) + linear_1d.bias.copy_(sharded_bias) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. 
+ """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
+ process_group = process_group[0] + + linear_1d = Linear1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + sharded_weight = shard_colwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + + if bias: + linear_1d.bias.copy_(module.bias.data) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/linearconv1d.py b/colossalai/shardformer/layer/linearconv1d.py new file mode 100644 index 000000000000..4a5cb0707900 --- /dev/null +++ b/colossalai/shardformer/layer/linearconv1d.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import 
Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device + +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class LinearConv1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. 
+ if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = LinearConv1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + + # first rearange the order of weight and bias + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_cast) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1) + rearanged_weight_chunks = [weight_chunks[i] for i in new_order] + rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1) + sharded_weight = shard_colwise(rearanged_weight, process_group) + linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + + if bias: + bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) + rearanged_bias_chunks = [bias_chunks[i] for i in new_order] + rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) + sharded_bias = shard_colwise(rearanged_bias, process_group) + linear_1d.bias.copy_(sharded_bias.contiguous()) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. 
+ output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class LinearConv1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' 
+ process_group = process_group[0] + + linear_1d = LinearConv1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + + # first rearange the order of weight and bias + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_cast) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0) + rearanged_weight_chunks = [weight_chunks[i] for i in new_order] + rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0) + sharded_weight = shard_rowwise(rearanged_weight, process_group) + linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + + if bias: + bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) + rearanged_bias_chunks = [bias_chunks[i] for i in new_order] + rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) + linear_1d.bias.copy_(rearanged_bias.contiguous()) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/parallelmodule.py b/colossalai/shardformer/layer/parallelmodule.py new file mode 100644 index 000000000000..3d19bbea7e47 --- /dev/null +++ b/colossalai/shardformer/layer/parallelmodule.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import List, Union + +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class ParallelModule(nn.Module, ABC): + + @abstractmethod + def from_native_module(module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. 
+ """ + pass diff --git a/colossalai/shardformer/layer/vocabparallelembedding1d.py b/colossalai/shardformer/layer/vocabparallelembedding1d.py new file mode 100644 index 000000000000..4c325c68421b --- /dev/null +++ b/colossalai/shardformer/layer/vocabparallelembedding1d.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from collections import OrderedDict +from typing import Callable, List, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.context import ParallelMode, seed +from colossalai.nn import init as init +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_rowwise +from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict + +from ._operation import reduce_input +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. 
+ """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.process_group = process_group + + tensor_parallel_size = dist.get_world_size(group=process_group) + tensor_parallel_rank = dist.get_rank(group=process_group) + + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings = self.num_embeddings_per_partition + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + + # ensure only one process group is used + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + # create the parallel module + vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + *args, + **kwargs) + with torch.no_grad(): + # shard and slice the weight along the vocabulary(num_embeddings) dimension + # the shape of the weight is (num_embeddings, embedding_dim) + shard_weight = shard_rowwise(module.weight.data, process_group) + vocab_embedding_1d.weight.data.copy_(shard_weight) + + return vocab_embedding_1d + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. 
+ input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, self.process_group) + return output diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 7b0eaa5d8ab1..93d004d7b10f 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -288,6 +288,15 @@ def module_policy(self): module_policy.update(addon_module) return module_policy + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + # BertForNextSentencePrediction class BertForNextSentencePredictionPolicy(BertPolicy): diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index ad98e3d073d4..bf828530415d 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -38,7 +38,26 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo def check_bert(rank, world_size, port): disable_existing_loggers() +<<<<<<< HEAD colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') +======= + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + forward_list = [ + BertForMaskedLM, + BertForPreTraining, + BertLMHeadModel, + + # TODO: do not work yet + # BertModel, + # BertForSequenceClassification + # BertForNextSentencePrediction, + ] + backward_lsit = [BertForMaskedLM, BertLMHeadModel] + + for model_fn in forward_list: + org_model, sharded_model = build_model(world_size, model_fn) + check_forward(org_model, sharded_model) +>>>>>>> 0cf164a2... 
update sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 2698d7675c8e..b5e0055801eb 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -45,7 +45,10 @@ def check_t5(rank, world_size, port): org_model, sharded_model = build_model(world_size, model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() + for model_fn in model_fn_list: + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model) + torch.cuda.empty_cache() @pytest.mark.dist From 326b3cde40f7af64c2c728dacf17f52b9e0ae3a7 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Sun, 25 Jun 2023 16:18:48 +0800 Subject: [PATCH 04/10] vit shard add vitattention vitlayer --- colossalai/shardformer/policies/vit.py | 109 +++++++++++++++++++------ 1 file changed, 83 insertions(+), 26 deletions(-) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 96527ce1350d..49815cd4bd63 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -2,9 +2,9 @@ import torch.nn as nn -from transformers.models.vit.modeling_vit import ViTModel, ViTLayer +from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention -from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D, LayerNorm1D, Dropout1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -23,35 +23,92 @@ def preprocess(self): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: return { - ViTLayer:ModulePolicyDescription( - - ), - ViTModel: + ViTEmbeddings: + ModulePolicyDescription( + attribute_replacement{}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ] + ), + ViTLayer: + ModulePolicyDescription( + attribute_replacement{}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=Dropout1D, + ), + SubModuleReplacementDescription( + suffix="layernorm_before", + target_module=LayerNorm1D, + ), + SubModuleReplacementDescription( + suffix="layernorm_after", + target_module=LayerNorm1D, + ), + ] + ), + ViTAttention: ModulePolicyDescription( attribute_replacement{ - - } + "attention.num_attention_heads": + self.config.num_attention_heads//self.shard_config.tensor_parallel_size, + + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.query", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.key", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.value", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.dropout", + 
target_module=Dropout1D, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=Dropout1D, + ), + ], + ), + ViTModel: + ModulePolicyDescription( + attribute_replacement{}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="layernorm", + target_module=LayerNorm1D, + ) + ] ), - } - @staticmethod - def embedding() -> List: - return[ - Embedding_Layer( - suffix="", - weight="weight", - replace_layer=col_nn.Embedding1D, - ) - ] - - @staticmethod - def dropout(): - return [Dropout_Layer( - suffix="dropout", - p="p", - replace_layer=col_nn.Dropout1D, - )] From 113eb55dbdf3f928bc87140a5dd5b794526c29e2 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 26 Jun 2023 14:04:40 +0800 Subject: [PATCH 05/10] update num head shard para --- colossalai/lazy/lazy_init.py | 11 +-- colossalai/shardformer/policies/vit.py | 58 ++++++------ colossalai/tensor/comm_spec.py | 65 +++++++------- colossalai/tensor/d_tensor/comm_spec.py | 88 +++++++++++-------- colossalai/tensor/d_tensor/layout.py | 8 +- .../tensor/d_tensor/layout_converter.py | 71 +++++++-------- tests/test_device/test_device_mesh.py | 74 +++++++++++++++- tests/test_device/test_init_logical_pg.py | 16 ++-- tests/test_lazy/lazy_init_utils.py | 4 +- tests/test_lazy/test_distribute.py | 28 +++--- .../test_dtensor/test_comm_spec.py | 33 +++++-- .../test_tensor/test_dtensor/test_dtensor.py | 2 +- .../test_dtensor/test_layout_converter.py | 43 ++++++--- tests/test_tensor/test_shape_consistency.py | 7 +- tests/test_tensor/test_sharded_linear.py | 2 +- tests/test_tensor/test_sharding_spec.py | 2 +- 16 files changed, 313 insertions(+), 199 deletions(-) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 8b911407307c..1e45eced5f34 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Callable, Dict, Optional, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -173,7 +173,7 @@ def materialize(self) -> torch.Tensor: self.clean() return _convert_cls(self, target) - def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: + def distribute(self, layout: Layout) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. Args: @@ -537,10 +537,7 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, - device_mesh: DeviceMesh, - sharding_spec_dict: Dict[str, ShardingSpec], - verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. 
Args: @@ -550,7 +547,7 @@ def distribute(module: nn.Module, """ def apply_fn(name: str, p: LazyTensor): - p.distribute(device_mesh, sharding_spec_dict[name]) + p.distribute(layout_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 49815cd4bd63..bd86381f99de 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -36,65 +36,59 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ViTLayer: ModulePolicyDescription( - attribute_replacement{}, + attribute_replacement{ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size//self.shard_config.tensor_parallel_size, + }, param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( - suffix="intermediate.dense", + suffix="attention.attention.query", target_module=Linear1D_Col, ), SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, + suffix="attention.attention.key", + target_module=Linear1D_Col, ), SubModuleReplacementDescription( - suffix="output.dropout", - target_module=Dropout1D, + suffix="attention.attention.value", + target_module=Linear1D_Col, ), SubModuleReplacementDescription( - suffix="layernorm_before", - target_module=LayerNorm1D, + suffix="attention.attention.dropout", + target_module=Dropout1D, ), SubModuleReplacementDescription( - suffix="layernorm_after", - target_module=LayerNorm1D, + suffix="attention.output.dense", + target_module=Linear1D_Row, ), - ] - ), - ViTAttention: - ModulePolicyDescription( - attribute_replacement{ - "attention.num_attention_heads": - self.config.num_attention_heads//self.shard_config.tensor_parallel_size, - - }, - param_replacement=[], - sub_module_replacement=[ SubModuleReplacementDescription( - suffix="attention.query", - target_module=Linear1D_Col, + suffix="attention.output.dropout", + target_module=Dropout1D, ), SubModuleReplacementDescription( - suffix="attention.key", + suffix="intermediate.dense", target_module=Linear1D_Col, ), SubModuleReplacementDescription( - suffix="attention.value", - target_module=Linear1D_Col, + suffix="output.dense", + target_module=Linear1D_Row, ), SubModuleReplacementDescription( - suffix="attention.dropout", + suffix="output.dropout", target_module=Dropout1D, ), SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, + suffix="layernorm_before", + target_module=LayerNorm1D, ), SubModuleReplacementDescription( - suffix="output.dropout", - target_module=Dropout1D, + suffix="layernorm_after", + target_module=LayerNorm1D, ), - ], + ] ), ViTModel: ModulePolicyDescription( diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index 204f81343199..fde819e5a379 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -34,48 +34,51 @@ def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. 
''' - process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() - process_group = process_groups[comm_spec.logical_process_axis] - - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) - start = length * dist.get_rank(process_group) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() - process_group = process_groups[comm_spec.logical_process_axis] - world_size = dist.get_world_size(process_group) - - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size - new_shape = torch.Size(new_shape) - output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // world_size - input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. 
''' - process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() - process_group = process_groups[comm_spec.logical_process_axis] - - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): @@ -411,7 +414,7 @@ def __init__(self, self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten() + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 79b2e3ef936a..159125fa16db 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -24,12 +24,12 @@ class CommSpec: ''' Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_group_dict to determine the process groups, gather_dim and shard_dim + communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. - process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. + process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. @@ -37,7 +37,7 @@ class CommSpec: def __init__(self, comm_pattern: CollectiveCommPattern, - process_group_dict: Dict, + process_groups_dict: Dict, gather_dim: int = None, shard_dim: int = None, logical_process_axis: int = None): @@ -45,7 +45,7 @@ def __init__(self, self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_group_dict = process_group_dict + self.process_groups_dict = process_groups_dict def __repr__(self): res_list = ["CommSpec:("] @@ -92,56 +92,68 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - world_size = dist.get_world_size(process_group) - tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] - # without this contiguous operation, the all gather may get some unexpected results. 
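# Note on the contiguity requirement above: a tensor produced by torch.narrow (as in the
# _split helper below) is a view of the original storage and is often non-contiguous, so
# .contiguous() is called first to hand all_gather a dense buffer; torch.distributed
# collectives generally assume contiguous inputs, which is why skipping this step can yield
# the "unexpected results" mentioned in the comment.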
- tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) - start = length * dist.get_rank(process_group) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - world_size = dist.get_world_size(process_group) - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size - new_shape = torch.Size(new_shape) - output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // world_size - input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): ''' Implement all reduce operation on device mesh based on information 
provided by comm_spec. ''' - process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -257,7 +269,7 @@ def symbolic(graph, input_): def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_group_dict=comm_spec.process_group_dict, + process_groups_dict=comm_spec.process_groups_dict, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index a35b2f43e44b..93588b5162f0 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -18,10 +18,12 @@ class Layout: global_shape: the entire shape of the global tensor. """ - def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, + entire_shape: torch.Size): self.device_mesh = device_mesh + self.device_type = device_type self.sharding_spec = sharding_spec - self.global_shape = global_shape + self.entire_shape = entire_shape self._sanity_check() def __hash__(self) -> int: @@ -53,7 +55,7 @@ def _sanity_check(self): # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.global_shape[dim] + tensor_dim_size = self.entire_shape[dim] num_devices = 1 for element in shard_list: diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 528ed7901c4f..14f9c4561622 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,8 +3,10 @@ from dataclasses import dataclass from typing import Dict, List, Tuple +import numpy as np import torch +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -35,9 +37,6 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): - """ - LayoutConverter is a singleton class which converts the layout of a distributed tensor. 
- """ def __init__(self): self._options = None @@ -80,14 +79,15 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec, - global_shape=global_shape) + entire_shape=entire_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -100,12 +100,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - - # the key of the dict is the axis - # the value is the process group - current_rank = source_layout.device_mesh._global_rank_of_current_process - process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] - + process_groups_dict = source_layout.device_mesh.process_groups_dict for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -123,7 +118,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_group_dict=process_group_dict, + process_groups_dict=process_groups_dict, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, @@ -134,7 +129,8 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -159,14 +155,15 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec, - global_shape=global_shape) + entire_shape=entire_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -179,12 +176,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - - # the key of the dict is the axis - # the value is the process group - current_rank = source_layout.device_mesh._global_rank_of_current_process - process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] - + process_groups_dict = source_layout.device_mesh.process_groups_dict source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -225,7 +217,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com 
shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, + process_groups_dict, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -248,7 +240,8 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -273,15 +266,16 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec, - global_shape=global_shape) + entire_shape=entire_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -295,11 +289,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - - # the key of the dict is the axis - # the value is the process group - current_rank = source_layout.device_mesh._global_rank_of_current_process - process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + process_groups_dict = source_layout.device_mesh.process_groups_dict # legal sharding dims means the mesh_id is still available to use. 
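# Roughly, for the 2 x 2 device mesh used in the docstring example above: a source spec of
# [S0, R, R] has already consumed mesh axis 0, so only axis 1 remains "legal"; appending it
# produces candidate targets such as [S01, R, R], [S0, S1, R] and [S0, R, S1], each paired
# with a SPLIT_FWD_GATHER_BWD CommSpec on that axis.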
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] @@ -327,7 +317,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, + process_groups_dict, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -338,7 +328,8 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -396,7 +387,7 @@ def layout_converting(self, source_layout: Layout, # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -404,14 +395,16 @@ def layout_converting(self, source_layout: Layout, # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - global_shape=global_shape) + entire_shape=entire_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - global_shape=global_shape) + entire_shape=entire_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) @@ -500,19 +493,21 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - global_shape = (4, 4, 4) + entire_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - global_shape=global_shape) + entire_shape=entire_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - global_shape=global_shape) + entire_shape=entire_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 590d6966bff6..0ae92649ad3e 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -8,16 +8,82 @@ def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16) + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.global_rank_to_local_rank(5) == [1, 1] - assert device_mesh.global_rank_to_local_rank(11) == [2, 3] - assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == 
[0, 1, 2, 3] + assert device_mesh.convert_map[5] == [1, 1] + assert device_mesh.convert_map[11] == [2, 3] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] + assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] + assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + + +def check_1d_device_mesh(): + # check for 1D device mesh + process_group = dist.GroupMember.WORLD + device_mesh = DeviceMesh.from_process_group(process_group) + + # checks + assert device_mesh.shape == [4] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict' + assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group' + assert device_mesh.is_initialized + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_2d_device_mesh(): + # create process group for 2D device mesh + first_row_ranks = [0, 1] + second_row_ranks = [2, 3] + first_col_ranks = [0, 2] + second_col_ranks = [1, 3] + + first_row_pg = dist.new_group(first_row_ranks, backend='nccl') + second_row_pg = dist.new_group(second_row_ranks, backend='nccl') + first_col_pg = dist.new_group(first_col_ranks, backend='nccl') + second_col_pg = dist.new_group(second_col_ranks, backend='nccl') + + # check for + current_rank = dist.get_rank() + + if current_rank in first_row_ranks: + row_pg = first_row_pg + else: + row_pg = second_row_pg + + if current_rank in first_col_ranks: + col_pg = first_col_pg + else: + col_pg = second_col_pg + + device_mesh = DeviceMesh.from_process_group([col_pg, row_pg]) + + # checks + assert device_mesh.shape == [2, 2] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict' + assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group' + assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group' + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_init_from_process_group(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_device_mesh_from_process_group(): + spawn(check_init_from_process_group, 4) def check_1d_device_mesh(): diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 7c6339eff67e..2b7060c4846a 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -20,12 +20,16 @@ def check_layer(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - - for axis in range(len(mesh_shape)): - tensor = torch.ones(4).cuda() - pg = device_mesh.get_process_group(axis=axis) - dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) - assert tensor.equal(tensor_to_check) + logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} + logical_process_groups = device_mesh.process_groups_dict + + for mesh_dim, pgs in logical_pg_dict.items(): + for index, pg in enumerate(pgs): + if rank in pg: + tensor = torch.ones(4).cuda() + group = 
logical_process_groups[mesh_dim][index][1] + dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) + assert tensor.equal(tensor_to_check) gpc.destroy() diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 73c3c5422d8a..3879363bcd1b 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -6,7 +6,6 @@ import torch from packaging import version -from colossalai.device.device_mesh import DeviceMesh from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor from colossalai.tensor.d_tensor import to_global from colossalai.tensor.d_tensor.layout import Layout @@ -83,8 +82,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, print(f'{model.__class__.__name__} pass') -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, - sharding_spec_dict: dict) -> None: +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index 622d9deb601d..94c8612a8b4c 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -26,19 +26,23 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]: return dim -def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: +def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: shard_dim = find_shard_dim(original_tensor.shape) dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - return target_sharding_spec + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=target_sharding_spec, + entire_shape=original_tensor.shape) + return layout def _get_current_name(prefix: str, name: str) -> str: return f'{prefix}.{name}'.lstrip('.') -def generate_sharding_spec_dict(model: nn.Module) -> dict: - sharding_spec_dict = {} +def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: + layout_dict = {} @torch.no_grad() def generate_recursively(module: nn.Module, prefix: str = ''): @@ -49,17 +53,17 @@ def generate_recursively(module: nn.Module, prefix: str = ''): # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): if isinstance(param, LazyTensor): - sharding_spec = make_sharding_spec(param) - sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec + layout = make_layout(device_mesh, param) + layout_dict[_get_current_name(prefix, name)] = layout for name, buf in module.named_buffers(recurse=False): if isinstance(buf, LazyTensor): - sharding_spec = make_sharding_spec(buf) - sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec + layout = make_layout(device_mesh, buf) + layout_dict[_get_current_name(prefix, name)] = layout generate_recursively(model) - return sharding_spec_dict + return layout_dict @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) @@ -81,9 +85,9 @@ def run_dist_lazy_init(subset, seed: int = 42): ctx = LazyInitContext() with ctx: deferred_model = model_fn() - sharding_spec_dict = generate_sharding_spec_dict(deferred_model) - ctx.distribute(deferred_model, device_mesh, 
sharding_spec_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) + layout_dict = generate_layout_dict(deferred_model, device_mesh) + ctx.distribute(deferred_model, layout_dict, verbose=True) + assert_dist_model_equal(model, deferred_model, layout_dict) def run_dist(rank, world_size, port) -> None: diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 95fcd2aaf8f3..958eabb65fac 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -122,6 +122,23 @@ def check_all_reduce_bwd(process_groups_dict, rank): assert tensor_to_comm.equal(tensor_to_check) +def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): + # tensor to comm + tensor_to_comm = torch.ones(2, 2).cuda() * rank + + # reduce through logical process axis 0 at flatten device mesh + # tensor to check + # tensor([[6., 6.], + # [6., 6.]]) + tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() + + # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) + comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + def check_comm(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -133,22 +150,24 @@ def check_comm(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - - process_group_dict = device_mesh._process_group_dict[rank] + process_groups_dict = device_mesh.process_groups_dict # test all gather - check_all_gather(process_group_dict, rank) + check_all_gather(process_groups_dict, rank) # test shard - check_shard(process_group_dict, rank) + check_shard(process_groups_dict, rank) # test all to all - check_all_to_all(process_group_dict, rank) + check_all_to_all(process_groups_dict, rank) # test all reduce - check_all_reduce_fwd(process_group_dict, rank) - check_all_reduce_bwd(process_group_dict, rank) + check_all_reduce_fwd(process_groups_dict, rank) + check_all_reduce_bwd(process_groups_dict, rank) + flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict + # test all reduce in 1D flatten device mesh + check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) gpc.destroy() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 5a1aef79f332..8350fb3e7fe6 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port): else: raise ValueError(f'rank {rank} is not in the device mesh') - dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) + dtensor_from_local = distribute_tensor(original_tensor, new_layout) if rank == 0: assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5388fd901e09..d9dff8af933d 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -12,9 +12,9 @@ from 
colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn -global_shape = torch.Size((64, 32, 16)) +entire_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() -physical_mesh_id = torch.arange(0, 4) +physical_mesh_id = torch.arange(0, 4).reshape(2, 2) mesh_shape = (2, 2) @@ -30,7 +30,10 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) @@ -46,7 +49,10 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) - layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape) + layout_all2all = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_all2all, + entire_shape=entire_shape) rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) @@ -65,7 +71,10 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) - shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape) + shard_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_shard, + entire_shape=entire_shape) rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) @@ -91,13 +100,19 @@ def check_layout_converting(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) @@ -122,7 +137,7 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - # checkout chached_spec_pairs_transform_path + # checkout cached_spec_pairs_transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence @@ -144,15 +159,21 @@ def check_layout_converting_apply(rank, world_size, port): # shard_sequence: 
R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) - original_tensor = torch.rand(global_shape).cuda() + original_tensor = torch.rand(entire_shape).cuda() # tensor_to_apply: [R, S01, R] tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 859eef051256..6fe9ee292cd0 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,10 +1,9 @@ +from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch - +from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -physical_mesh_id = torch.arange(0, 16) +physical_mesh_id = torch.arange(0, 16).reshape(2, 8) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index 9bd9805e9b8f..d66d4fec14d1 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -26,7 +26,7 @@ def run_dist(rank, world_size, port): # the mesh is in the following topo # [[0, 1], # [2, 3]] - physical_mesh_id = torch.arange(0, 4) + physical_mesh_id = torch.arange(0, 4).reshape(2, 2) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) row_id = rank // 2 diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 5007c4141849..909c84ef0f0e 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -5,7 +5,7 @@ def test_sharding_spec(): - physical_mesh_id = torch.arange(0, 16) + physical_mesh_id = torch.arange(0, 16).reshape(2, 8) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], From 4c47c5bdac704038d68adda27a0214cc14723c22 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Mon, 26 Jun 2023 18:19:56 +0800 Subject: [PATCH 06/10] finish test for vit --- colossalai/lazy/lazy_init.py | 11 ++- colossalai/shardformer/policies/vit.py | 4 +- colossalai/tensor/comm_spec.py | 65 +++++++------- colossalai/tensor/d_tensor/comm_spec.py | 88 ++++++++----------- colossalai/tensor/d_tensor/layout.py | 8 +- .../tensor/d_tensor/layout_converter.py | 71 ++++++++------- tests/test_device/test_device_mesh.py | 10 +-- tests/test_device/test_init_logical_pg.py | 16 ++-- tests/test_lazy/lazy_init_utils.py | 4 +- tests/test_lazy/test_distribute.py | 28 +++--- .../test_dtensor/test_comm_spec.py | 33 ++----- .../test_tensor/test_dtensor/test_dtensor.py | 2 +- 
.../test_dtensor/test_layout_converter.py | 43 +++------ tests/test_tensor/test_shape_consistency.py | 7 +- tests/test_tensor/test_sharded_linear.py | 2 +- tests/test_tensor/test_sharding_spec.py | 2 +- 16 files changed, 169 insertions(+), 225 deletions(-) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 1e45eced5f34..8b911407307c 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import torch import torch.distributed as dist @@ -173,7 +173,7 @@ def materialize(self) -> torch.Tensor: self.clean() return _convert_cls(self, target) - def distribute(self, layout: Layout) -> torch.Tensor: + def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. Args: @@ -537,7 +537,10 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, + device_mesh: DeviceMesh, + sharding_spec_dict: Dict[str, ShardingSpec], + verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -547,7 +550,7 @@ def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> n """ def apply_fn(name: str, p: LazyTensor): - p.distribute(layout_dict[name]) + p.distribute(device_mesh, sharding_spec_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index bd86381f99de..03de6aac5de2 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,10 +1,10 @@ -rom typing import Dict, Union +from typing import Dict, Union import torch.nn as nn from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention -from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D, LayerNorm1D, Dropout1D +from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, LayerNorm1D, Dropout1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index fde819e5a379..204f81343199 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -34,51 +34,48 @@ def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. 
''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. 
''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): @@ -414,7 +411,7 @@ def __init__(self, self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.device_mesh = self.sharding_spec.device_mesh.flatten() self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 159125fa16db..79b2e3ef936a 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -24,12 +24,12 @@ class CommSpec: ''' Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim + communication method, process_group_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. - process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. @@ -37,7 +37,7 @@ class CommSpec: def __init__(self, comm_pattern: CollectiveCommPattern, - process_groups_dict: Dict, + process_group_dict: Dict, gather_dim: int = None, shard_dim: int = None, logical_process_axis: int = None): @@ -45,7 +45,7 @@ def __init__(self, self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_groups_dict = process_groups_dict + self.process_group_dict = process_group_dict def __repr__(self): res_list = ["CommSpec:("] @@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - # without this contiguous operation, the all gather may get some unexpected results. 
- tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. 
''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -269,7 +257,7 @@ def symbolic(graph, input_): def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_groups_dict=comm_spec.process_groups_dict, + process_group_dict=comm_spec.process_group_dict, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index 93588b5162f0..a35b2f43e44b 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -18,12 +18,10 @@ class Layout: global_shape: the entire shape of the global tensor. """ - def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, - entire_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): self.device_mesh = device_mesh - self.device_type = device_type self.sharding_spec = sharding_spec - self.entire_shape = entire_shape + self.global_shape = global_shape self._sanity_check() def __hash__(self) -> int: @@ -55,7 +53,7 @@ def _sanity_check(self): # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.entire_shape[dim] + tensor_dim_size = self.global_shape[dim] num_devices = 1 for element in shard_list: diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 14f9c4561622..528ed7901c4f 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,10 +3,8 @@ from dataclasses import dataclass from typing import Dict, List, Tuple -import numpy as np import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): + """ + LayoutConverter is a singleton class which converts the layout of a distributed tensor. 
+ """ def __init__(self): self._options = None @@ -79,15 +80,14 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -100,7 +100,12 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -118,7 +123,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_groups_dict=process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, @@ -129,8 +134,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -155,15 +159,14 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -176,7 +179,12 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -217,7 +225,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com 
shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -240,8 +248,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -266,16 +273,15 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -289,7 +295,11 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] # legal sharding dims means the mesh_id is still available to use. 
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] @@ -317,7 +327,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -328,8 +338,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -387,7 +396,7 @@ def layout_converting(self, source_layout: Layout, # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -395,16 +404,14 @@ def layout_converting(self, source_layout: Layout, # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) @@ -493,21 +500,19 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 0ae92649ad3e..43b1f4276e8a 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -8,18 +8,16 @@ def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.convert_map[5] == [1, 1] - assert device_mesh.convert_map[11] == [2, 3] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], 
[3, 0]] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] - assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + assert device_mesh.global_rank_to_local_rank(5) == [1, 1] + assert device_mesh.global_rank_to_local_rank(11) == [2, 3] + assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3] def check_1d_device_mesh(): diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 2b7060c4846a..7c6339eff67e 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -20,16 +20,12 @@ def check_layer(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} - logical_process_groups = device_mesh.process_groups_dict - - for mesh_dim, pgs in logical_pg_dict.items(): - for index, pg in enumerate(pgs): - if rank in pg: - tensor = torch.ones(4).cuda() - group = logical_process_groups[mesh_dim][index][1] - dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) - assert tensor.equal(tensor_to_check) + + for axis in range(len(mesh_shape)): + tensor = torch.ones(4).cuda() + pg = device_mesh.get_process_group(axis=axis) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) + assert tensor.equal(tensor_to_check) gpc.destroy() diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 3879363bcd1b..73c3c5422d8a 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -6,6 +6,7 @@ import torch from packaging import version +from colossalai.device.device_mesh import DeviceMesh from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor from colossalai.tensor.d_tensor import to_global from colossalai.tensor.d_tensor.layout import Layout @@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, print(f'{model.__class__.__name__} pass') -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, + sharding_spec_dict: dict) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index 94c8612a8b4c..622d9deb601d 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]: return dim -def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: +def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: shard_dim = find_shard_dim(original_tensor.shape) dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - return layout + return target_sharding_spec def _get_current_name(prefix: str, name: str) -> str: return f'{prefix}.{name}'.lstrip('.') -def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: - layout_dict = {} +def 
generate_sharding_spec_dict(model: nn.Module) -> dict: + sharding_spec_dict = {} @torch.no_grad() def generate_recursively(module: nn.Module, prefix: str = ''): @@ -53,17 +49,17 @@ def generate_recursively(module: nn.Module, prefix: str = ''): # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): if isinstance(param, LazyTensor): - layout = make_layout(device_mesh, param) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(param) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec for name, buf in module.named_buffers(recurse=False): if isinstance(buf, LazyTensor): - layout = make_layout(device_mesh, buf) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(buf) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec generate_recursively(model) - return layout_dict + return sharding_spec_dict @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) @@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42): ctx = LazyInitContext() with ctx: deferred_model = model_fn() - layout_dict = generate_layout_dict(deferred_model, device_mesh) - ctx.distribute(deferred_model, layout_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, layout_dict) + sharding_spec_dict = generate_sharding_spec_dict(deferred_model) + ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True) + assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) def run_dist(rank, world_size, port) -> None: diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 958eabb65fac..95fcd2aaf8f3 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -122,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank): assert tensor_to_comm.equal(tensor_to_check) -def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): - # tensor to comm - tensor_to_comm = torch.ones(2, 2).cuda() * rank - - # reduce through logical process axis 0 at flatten device mesh - # tensor to check - # tensor([[6., 6.], - # [6., 6.]]) - tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() - - # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) - comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) - tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) - - assert tensor_to_comm.equal(tensor_to_check) - - def check_comm(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -150,24 +133,22 @@ def check_comm(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - process_groups_dict = device_mesh.process_groups_dict + + process_group_dict = device_mesh._process_group_dict[rank] # test all gather - check_all_gather(process_groups_dict, rank) + check_all_gather(process_group_dict, rank) # test shard - check_shard(process_groups_dict, rank) + check_shard(process_group_dict, rank) # test all to all - check_all_to_all(process_groups_dict, rank) + check_all_to_all(process_group_dict, rank) # test all reduce - check_all_reduce_fwd(process_groups_dict, rank) - 
check_all_reduce_bwd(process_groups_dict, rank) + check_all_reduce_fwd(process_group_dict, rank) + check_all_reduce_bwd(process_group_dict, rank) - flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict - # test all reduce in 1D flatten device mesh - check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) gpc.destroy() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 8350fb3e7fe6..5a1aef79f332 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port): else: raise ValueError(f'rank {rank} is not in the device mesh') - dtensor_from_local = distribute_tensor(original_tensor, new_layout) + dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) if rank == 0: assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index d9dff8af933d..5388fd901e09 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn -entire_shape = torch.Size((64, 32, 16)) +global_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() -physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec, - entire_shape=entire_shape) + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) @@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) - layout_all2all = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_all2all, - entire_shape=entire_shape) + layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape) rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) @@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) - shard_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_shard, - entire_shape=entire_shape) + shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape) rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) @@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, 
- device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) @@ -137,7 +122,7 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - # checkout cached_spec_pairs_transform_path + # checkout chached_spec_pairs_transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence @@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) - original_tensor = torch.rand(entire_shape).cuda() + original_tensor = torch.rand(global_shape).cuda() # tensor_to_apply: [R, S01, R] tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 6fe9ee292cd0..859eef051256 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,9 +1,10 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec + from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index d66d4fec14d1..9bd9805e9b8f 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -26,7 +26,7 @@ def run_dist(rank, world_size, port): # the mesh is in the following topo # [[0, 1], # [2, 3]] - physical_mesh_id = torch.arange(0, 4).reshape(2, 2) + physical_mesh_id = torch.arange(0, 4) 
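The hunks above all apply the same refactor: `physical_mesh_id` becomes a flat tensor of ranks, `Layout` drops `device_type` and renames `entire_shape` to `global_shape`, and `CommSpec` receives the current rank's flat `process_group_dict` keyed by logical process axis instead of the old nested `process_groups_dict`. Below is a minimal sketch tying the new pieces together; it only uses signatures that appear in these diffs, assumes the import paths shown in them, and is meant to run inside an already-launched distributed process (as in the spawned tests), since `init_process_group=True` and the per-rank lookup require `torch.distributed` to be initialized.

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# flat rank ids: the old .reshape(2, 2) / .reshape(2, 8) calls are dropped
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# Layout now takes (device_mesh, sharding_spec, global_shape) only
global_shape = torch.Size((4, 4, 4))
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})    # [S0, S1, R]
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)

# CommSpec takes the current rank's {logical axis: ProcessGroup} mapping
current_rank = device_mesh._global_rank_of_current_process
process_group_dict = device_mesh._process_group_dict[current_rank]
comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                     process_group_dict=process_group_dict,
                     gather_dim=0,
                     shard_dim=0,
                     logical_process_axis=0)
```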
mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) row_id = rank // 2 diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 909c84ef0f0e..5007c4141849 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -5,7 +5,7 @@ def test_sharding_spec(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], From 6db5acc3722de9dc8550aa82237aaad592d27cf5 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 27 Jun 2023 11:06:12 +0800 Subject: [PATCH 07/10] add new_model_class & postprocess --- colossalai/shardformer/policies/vit.py | 38 +++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 03de6aac5de2..0c927c287244 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -80,27 +80,27 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: suffix="output.dropout", target_module=Dropout1D, ), - SubModuleReplacementDescription( - suffix="layernorm_before", - target_module=LayerNorm1D, - ), - SubModuleReplacementDescription( - suffix="layernorm_after", - target_module=LayerNorm1D, - ), - ] - ), - ViTModel: - ModulePolicyDescription( - attribute_replacement{}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="layernorm", - target_module=LayerNorm1D, - ) + # SubModuleReplacementDescription( + # suffix="layernorm_before", + # target_module=LayerNorm1D, + # ), + # SubModuleReplacementDescription( + # suffix="layernorm_after", + # target_module=LayerNorm1D, + # ), ] ), + # ViTModel: + # ModulePolicyDescription( + # attribute_replacement{}, + # param_replacement=[], + # sub_module_replacement=[ + # SubModuleReplacementDescription( + # suffix="layernorm", + # target_module=LayerNorm1D, + # ) + # ] + # ), } From 7e9883a3b65621bb621f22ecf15bf6d34ab2e7e0 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 27 Jun 2023 17:36:44 +0800 Subject: [PATCH 08/10] add vit readme --- colossalai/shardformer/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index fee4cce7a28a..da80a7276b68 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer: - [ ] GPT Neo - [ ] GPT-J - [ ] CV - - [ ] ViT + - [x] ViT - [ ] BEiT - [ ] SwinTransformer - [ ] SwinTransformer V2 From 91201297c7ff97ebc39521b74ff996f9bf12d387 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Tue, 27 Jun 2023 18:33:15 +0800 Subject: [PATCH 09/10] delete old files & fix the conflict --- colossalai/shardformer/layer/embedding1d.py | 149 ------- colossalai/shardformer/layer/layernorm1d.py | 73 ---- colossalai/shardformer/layer/linear1d.py | 346 ---------------- colossalai/shardformer/layer/linearconv1d.py | 377 ------------------ .../shardformer/layer/parallelmodule.py | 35 -- .../layer/vocabparallelembedding1d.py | 170 -------- colossalai/shardformer/policies/vit.py | 27 +- tests/test_device/test_device_mesh.py | 66 +-- .../test_model/test_shard_bert.py | 19 - .../test_model/test_shard_t5.py | 7 +- 10 files changed, 10 insertions(+), 1259 deletions(-) delete mode 100644 
colossalai/shardformer/layer/embedding1d.py delete mode 100644 colossalai/shardformer/layer/layernorm1d.py delete mode 100644 colossalai/shardformer/layer/linear1d.py delete mode 100644 colossalai/shardformer/layer/linearconv1d.py delete mode 100644 colossalai/shardformer/layer/parallelmodule.py delete mode 100644 colossalai/shardformer/layer/vocabparallelembedding1d.py diff --git a/colossalai/shardformer/layer/embedding1d.py b/colossalai/shardformer/layer/embedding1d.py deleted file mode 100644 index 1108d5d6a936..000000000000 --- a/colossalai/shardformer/layer/embedding1d.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Callable, List, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.nn import init as init -from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise -from colossalai.utils.cuda import get_current_device - -from ._operation import gather_forward_split_backward -from .parallelmodule import ParallelModule -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class Embedding1D(ParallelModule): - r"""Embedding for 1D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. 
- - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - self.process_group = process_group - self.num_partitions = dist.get_world_size(process_group) - self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - # self.gather_output = gather_output - - if device is None: - device = get_current_device() - - self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - - @staticmethod - def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": - r""" - Build a 1D parallelized Embedding from a native nn.Embedding module. - """ - # get the attributes - num_embedding = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - max_norm = module.max_norm - norm_type = module.norm_type - scale_grad_by_freq = module.scale_grad_by_freq - sparse = module.sparse - dtype = module.weight.dtype - device = module.weight.device - - # sparse is not support yet - if sparse: - raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") - - embedding = Embedding1D(num_embeddings=num_embedding, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - process_group=process_group, - dtype=dtype, - device=device, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - - return embedding - - def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - - return output diff --git a/colossalai/shardformer/layer/layernorm1d.py b/colossalai/shardformer/layer/layernorm1d.py deleted file mode 100644 index 78bd64cfb504..000000000000 --- a/colossalai/shardformer/layer/layernorm1d.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from collections import OrderedDict - -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import LayerNorm -from colossalai.nn 
import init as init -from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.utils.checkpointing import broadcast_state_dict - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class LayerNorm1D(ColossalaiModule): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps).to(dtype) - super().__init__(norm) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/layer/linear1d.py b/colossalai/shardformer/layer/linear1d.py deleted file mode 100644 index d59d32df824e..000000000000 --- a/colossalai/shardformer/layer/linear1d.py +++ /dev/null @@ -1,346 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -from typing import Callable, List, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.nn import init as init -from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.cuda import get_current_device - -from ._operation import ( - gather_forward_split_backward, - linear_with_async_comm, - reduce_input, - split_forward_gather_backward, -) -from .parallelmodule import ParallelModule -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except 
ImportError: - pass - - -class Linear1D_Col(ParallelModule): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - device (`torch.device`): The device of parameters, defaults to None. - process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (`typing.Callable`): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (`typing.Callable`): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - self.device = device - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - self.out_features_per_partition = divide(out_features, self.num_partitions) - - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native PyTorch linear layer to a parallelized linear layer. - """ - # get the attributes - in_features = module.in_features - out_features = module.out_features - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' 
- process_group = process_group[0] - - linear_1d = Linear1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on row is equal to shard on column - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - if bias: - sharded_bias = shard_colwise(module.bias.data, process_group) - linear_1d.bias.copy_(sharded_bias) - - return linear_1d - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output - - -class Linear1D_Row(ParallelModule): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. 
- """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, self.num_partitions) - - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() - - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native PyTorch linear layer to a parallelized linear layer. - """ - # get the attributes - in_features = module.in_features - out_features = module.out_features - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' 
- process_group = process_group[0] - - linear_1d = Linear1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_colwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - - if bias: - linear_1d.bias.copy_(module.bias.data) - - return linear_1d - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - if self.process_group is None: - src_rank = 0 - else: - src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) - - origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) - - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias diff --git a/colossalai/shardformer/layer/linearconv1d.py b/colossalai/shardformer/layer/linearconv1d.py deleted file mode 100644 index 4a5cb0707900..000000000000 --- a/colossalai/shardformer/layer/linearconv1d.py +++ /dev/null @@ -1,377 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -from typing import Callable, List, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import 
Parameter - -from colossalai.nn import init as init -from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.cuda import get_current_device - -from ._operation import ( - gather_forward_split_backward, - linear_with_async_comm, - reduce_input, - split_forward_gather_backward, -) -from .parallelmodule import ParallelModule -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class LinearConv1D_Col(ParallelModule): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - device (`torch.device`): The device of parameters, defaults to None. - process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (`typing.Callable`): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (`typing.Callable`): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - self.device = device - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - self.out_features_per_partition = divide(out_features, self.num_partitions) - - # Parameters. - # Initialize weight. 
- if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, - *args, **kwargs) -> ParallelModule: - r""" - Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. - """ - # get the attributes - in_features = module.weight.shape[0] - out_features = module.weight.shape[1] - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] - - linear_1d = LinearConv1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on row is equal to shard on column - - # first rearange the order of weight and bias - world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_cast) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1) - rearanged_weight_chunks = [weight_chunks[i] for i in new_order] - rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1) - sharded_weight = shard_colwise(rearanged_weight, process_group) - linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) - - if bias: - bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) - rearanged_bias_chunks = [bias_chunks[i] for i in new_order] - rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) - sharded_bias = shard_colwise(rearanged_bias, process_group) - linear_1d.bias.copy_(sharded_bias.contiguous()) - - return linear_1d - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - - if self.gather_output: - # All-gather across the partitions. 
- output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output - - -class LinearConv1D_Row(ParallelModule): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, self.num_partitions) - - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() - - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, - *args, **kwargs) -> ParallelModule: - r""" - Convert a native PyTorch linear layer to a parallelized linear layer. - """ - # get the attributes - in_features = module.weight.shape[0] - out_features = module.weight.shape[1] - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' 
- process_group = process_group[0] - - linear_1d = LinearConv1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - - # first rearange the order of weight and bias - world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_cast) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0) - rearanged_weight_chunks = [weight_chunks[i] for i in new_order] - rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0) - sharded_weight = shard_rowwise(rearanged_weight, process_group) - linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) - - if bias: - bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) - rearanged_bias_chunks = [bias_chunks[i] for i in new_order] - rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) - linear_1d.bias.copy_(rearanged_bias.contiguous()) - - return linear_1d - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - if self.process_group is None: - src_rank = 0 - else: - src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) - - origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. 
Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) - - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias diff --git a/colossalai/shardformer/layer/parallelmodule.py b/colossalai/shardformer/layer/parallelmodule.py deleted file mode 100644 index 3d19bbea7e47..000000000000 --- a/colossalai/shardformer/layer/parallelmodule.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from abc import ABC, abstractmethod -from typing import List, Union - -import torch.nn as nn -from torch.distributed import ProcessGroup - -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import init as init - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class ParallelModule(nn.Module, ABC): - - @abstractmethod - def from_native_module(module: nn.Module, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": - """ - Convert a native PyTorch module to a parallelized module. - - Args: - module (nn.Module): the module to be converted. - process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. - If this is a list, the process group at the ith index of the list will correspond to the process group - in the ith axis of the device mesh. Defaults to None, which means the global process group. 
- """ - pass diff --git a/colossalai/shardformer/layer/vocabparallelembedding1d.py b/colossalai/shardformer/layer/vocabparallelembedding1d.py deleted file mode 100644 index 4c325c68421b..000000000000 --- a/colossalai/shardformer/layer/vocabparallelembedding1d.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from collections import OrderedDict -from typing import Callable, List, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.context import ParallelMode, seed -from colossalai.nn import init as init -from colossalai.nn.layer.base_layer import ParallelLayer -from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_rowwise -from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict - -from ._operation import reduce_input -from .parallelmodule import ParallelModule -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class VocabParallelEmbedding1D(ParallelLayer): - r"""Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. 
- """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - self.process_group = process_group - - tensor_parallel_size = dist.get_world_size(group=process_group) - tensor_parallel_rank = dist.get_rank(group=process_group) - - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings = self.num_embeddings_per_partition - self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - - @staticmethod - def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native pytorch embedding module to a parallel module. - """ - # get the origin attributes - num_embeddings = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - device = module.weight.device - - # ensure only one process group is used - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] - - # create the parallel module - vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - device=device, - process_group=process_group, - *args, - **kwargs) - with torch.no_grad(): - # shard and slice the weight along the vocabulary(num_embeddings) dimension - # the shape of the weight is (num_embeddings, embedding_dim) - shard_weight = shard_rowwise(module.weight.data, process_group) - vocab_embedding_1d.weight.data.copy_(shard_weight) - - return vocab_embedding_1d - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Build the mask. 
- input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) - - # Mask the output embedding. - output_parallel[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. - output = reduce_input(output_parallel, self.process_group) - return output diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 0c927c287244..4a2b72057d05 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -4,7 +4,7 @@ from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention -from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, LayerNorm1D, Dropout1D +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -80,28 +80,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: suffix="output.dropout", target_module=Dropout1D, ), - # SubModuleReplacementDescription( - # suffix="layernorm_before", - # target_module=LayerNorm1D, - # ), - # SubModuleReplacementDescription( - # suffix="layernorm_after", - # target_module=LayerNorm1D, - # ), ] ), - # ViTModel: - # ModulePolicyDescription( - # attribute_replacement{}, - # param_replacement=[], - # sub_module_replacement=[ - # SubModuleReplacementDescription( - # suffix="layernorm", - # target_module=LayerNorm1D, - # ) - # ] - # ), } + + def new_model_class(self): + return None + + def postprocess(self): + return self.model diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 43b1f4276e8a..1f8db99c9236 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -84,70 +84,6 @@ def test_device_mesh_from_process_group(): spawn(check_init_from_process_group, 4) -def check_1d_device_mesh(): - # check for 1D device mesh - process_group = dist.GroupMember.WORLD - device_mesh = DeviceMesh.from_process_group(process_group) - - # checks - assert device_mesh.shape == [4] - assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict' - assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group' - assert device_mesh.is_initialized - assert device_mesh.num_devices == 4 - assert device_mesh.is_initialized - assert device_mesh.logical_mesh_id is None - assert device_mesh._is_init_from_process_group - - -def check_2d_device_mesh(): - # create process group for 2D device mesh - first_row_ranks = [0, 1] - second_row_ranks = [2, 3] - first_col_ranks = [0, 2] - second_col_ranks = [1, 3] - - first_row_pg = dist.new_group(first_row_ranks, backend='nccl') - second_row_pg = dist.new_group(second_row_ranks, backend='nccl') - first_col_pg = dist.new_group(first_col_ranks, backend='nccl') - second_col_pg = dist.new_group(second_col_ranks, backend='nccl') - - # check for - current_rank = dist.get_rank() - - if current_rank in first_row_ranks: - row_pg = first_row_pg - else: - row_pg = second_row_pg - - if current_rank in first_col_ranks: - col_pg = first_col_pg - else: - col_pg = second_col_pg - - device_mesh = DeviceMesh.from_process_group([col_pg, row_pg]) - - # checks - assert 
device_mesh.shape == [2, 2] - assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict' - assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group' - assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group' - assert device_mesh.num_devices == 4 - assert device_mesh.is_initialized - assert device_mesh.logical_mesh_id is None - assert device_mesh._is_init_from_process_group - - -def check_init_from_process_group(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_device_mesh_from_process_group(): - spawn(check_init_from_process_group, 4) - - if __name__ == '__main__': test_device_mesh() - test_device_mesh_from_process_group() + test_device_mesh_from_process_group() \ No newline at end of file diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index bf828530415d..ad98e3d073d4 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -38,26 +38,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo def check_bert(rank, world_size, port): disable_existing_loggers() -<<<<<<< HEAD colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') -======= - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - forward_list = [ - BertForMaskedLM, - BertForPreTraining, - BertLMHeadModel, - - # TODO: do not work yet - # BertModel, - # BertForSequenceClassification - # BertForNextSentencePrediction, - ] - backward_lsit = [BertForMaskedLM, BertLMHeadModel] - - for model_fn in forward_list: - org_model, sharded_model = build_model(world_size, model_fn) - check_forward(org_model, sharded_model) ->>>>>>> 0cf164a2... 
update sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index b5e0055801eb..6074a902e9b0 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -45,10 +45,7 @@ def check_t5(rank, world_size, port): org_model, sharded_model = build_model(world_size, model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - for model_fn in model_fn_list: - org_model, sharded_model = build_model(world_size, model_fn) - check_forward_backward(org_model, sharded_model) - torch.cuda.empty_cache() + torch.cuda.empty_cache() @pytest.mark.dist @@ -59,4 +56,4 @@ def test_t5(): if __name__ == "__main__": - test_t5() + test_t5() \ No newline at end of file From 0c93dfa570875fb37b1b347a1d625d9a88ab4d96 Mon Sep 17 00:00:00 2001 From: klhhhhh <1412841649@qq.com> Date: Wed, 28 Jun 2023 11:47:18 +0800 Subject: [PATCH 10/10] fix sth --- colossalai/shardformer/layer/_operation.py | 2 +- colossalai/shardformer/layer/layernorm.py | 2 +- colossalai/shardformer/policies/bert.py | 11 +---------- colossalai/shardformer/policies/t5.py | 2 +- tests/test_shardformer/test_layer/test_layernorm.py | 2 +- 5 files changed, 5 insertions(+), 14 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01b33..c025daaeccc7 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -287,4 +287,4 @@ def reduce_forward(input_, process_group): def reduce_backward(input_, process_group): - return _ReduceBackward.apply(input_, process_group) + return _ReduceBackward.apply(input_, process_group) \ No newline at end of file diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/layernorm.py index 83854239cf90..6103380fe8a5 100644 --- a/colossalai/shardformer/layer/layernorm.py +++ b/colossalai/shardformer/layer/layernorm.py @@ -61,4 +61,4 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: # copy weight and bias layernorm.weight.copy_(module.weight) layernorm.bias.copy_(module.bias) - return layernorm + return layernorm \ No newline at end of file diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 93d004d7b10f..fb70cdff8824 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -288,15 +288,6 @@ def module_policy(self): module_policy.update(addon_module) return module_policy - def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) - setattr_(self.model, v, param) - return self.model - # BertForNextSentencePrediction class BertForNextSentencePredictionPolicy(BertPolicy): @@ -325,4 +316,4 @@ def module_policy(self): ]) } module_policy.update(addon_module) - return module_policy + return module_policy \ No newline at end of file diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 30433f751088..9a1b63e46d2c 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -167,4 +167,4 @@ 
def module_policy(self): class T5EncoderPolicy(T5ModelPolicy): - pass + pass \ No newline at end of file diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index a117845545be..080fae034956 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -41,4 +41,4 @@ def test_layernorm(): if __name__ == '__main__': - test_layernorm_1d() + test_layernorm_1d() \ No newline at end of file
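
Note on the fused-weight reordering in the LinearConv1D converters above: before column- or row-sharding a GPT-2 style Conv1D weight that stores several projections side by side (e.g. q/k/v in c_attn), `from_native_module` chunks the weight into `world_size * n_cast` pieces and interleaves them so that, after a plain shard, every rank holds its own slice of each fused projection. The following is a minimal standalone sketch of that reordering; the function and variable names are illustrative only and are not part of the patch.

```python
# Standalone sketch of the interleaved reordering used by LinearConv1D_Col.from_native_module;
# names here are illustrative, not part of the patch.
import torch


def rearrange_fused_weight(weight: torch.Tensor, world_size: int, n_cast: int) -> torch.Tensor:
    """Reorder a fused projection weight (e.g. GPT-2 Conv1D storing q/k/v side by side,
    shape (in_features, n_cast * hidden)) so that a plain column-wise shard afterwards
    gives every rank a contiguous slice of each of the n_cast projections."""
    order = torch.arange(world_size * n_cast)
    # chunks i, i + world_size, i + 2 * world_size, ... are the pieces rank i needs,
    # so group them together before sharding
    new_order = torch.cat([order[i::world_size] for i in range(world_size)])
    chunks = torch.chunk(weight, world_size * n_cast, dim=1)
    return torch.cat([chunks[int(i)] for i in new_order], dim=1)


if __name__ == "__main__":
    # toy check: 2 ranks, fused q/k/v (n_cast=3), per-projection width 4
    w = torch.arange(2 * 12, dtype=torch.float32).reshape(2, 12)
    print(rearrange_fused_weight(w, world_size=2, n_cast=3))
```

With world_size=2 and n_cast=3 the chunk order [Q0, Q1, K0, K1, V0, V1] becomes [Q0, K0, V0, Q1, K1, V1], so rank 0 receives the first halves of Q, K and V and rank 1 the second halves once the result is sharded column-wise.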
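
The forward pass of the deleted VocabParallelEmbedding1D hides out-of-range token ids before the local lookup and relies on a later all-reduce to combine the per-rank partial outputs. Below is a minimal single-process sketch of just that masking step, with hypothetical names and without the distributed reduce, as an assumption-labelled illustration rather than the library's implementation.

```python
# Single-process sketch (illustrative names, not part of the patch) of the token-masking
# step in VocabParallelEmbedding1D.forward: each rank holds rows
# [vocab_start_index, vocab_end_index) of the embedding table, looks up only the ids it
# owns, zeroes the rest, and the ranks' partial outputs are later summed with an all-reduce.
import torch
import torch.nn.functional as F


def local_embedding_lookup(input_ids: torch.Tensor, local_weight: torch.Tensor,
                           vocab_start_index: int, vocab_end_index: int) -> torch.Tensor:
    # ids outside this rank's shard contribute zero rows; the all-reduce (omitted here)
    # fills them in from the rank that owns those ids
    input_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
    masked_input = input_ids.clone() - vocab_start_index
    masked_input[input_mask] = 0
    output_parallel = F.embedding(masked_input, local_weight)
    output_parallel[input_mask, :] = 0.0
    return output_parallel


if __name__ == "__main__":
    # toy shard: this "rank" owns vocab ids 4..7 of an 8-id vocabulary
    weight_shard = torch.randn(4, 3)
    ids = torch.tensor([1, 4, 7, 2])
    print(local_embedding_lookup(ids, weight_shard, vocab_start_index=4, vocab_end_index=8))
```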