From ae2ce5012749e34d8110dfc8e30ce0879357cae7 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 10:55:13 +0800 Subject: [PATCH 01/16] [shardformer] embedding support inplace sharding --- colossalai/shardformer/layer/embedding.py | 66 +++++++++++-------- colossalai/tensor/d_tensor/api.py | 8 +++ .../test_layer/test_embedding.py | 6 +- 3 files changed, 49 insertions(+), 31 deletions(-) diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 07341ef73515..a2e5434c9d77 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import torch import torch.distributed as dist @@ -13,7 +13,12 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param +from colossalai.tensor.d_tensor.api import ( + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, + sharded_tensor_to_param, +) from ._operation import gather_forward_split_backward, reduce_forward from .parallel_module import ParallelModule @@ -60,6 +65,7 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = True, + weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -74,18 +80,23 @@ def __init__(self, self.embed_kwargs = kwargs self.gather_output = gather_output - # Parameters. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) - sharded_weight = shard_colwise(weight, process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) + # Parameters. 
+ if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + sharded_weight = shard_colwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, @@ -121,14 +132,10 @@ def from_native_module(module: nn.Embedding, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, + weight=module.weight, *args, **kwargs) - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - return embedding def reset_parameters(self, weight_initializer) -> None: @@ -143,7 +150,6 @@ def _fill_padding_idx_with_zero(self) -> None: def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - if self.gather_output: output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) return output @@ -188,6 +194,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -207,16 +214,22 @@ def __init__(self, self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - # parameter - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) - sharded_weight = shard_rowwise(weight, process_group) - self.weight = sharded_tensor_to_param(sharded_weight) - # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - self.reset_parameters(weight_initializer) + + # parameter + if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + sharded_weight = shard_rowwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) + + if weight is None: + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -243,15 +256,10 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, padding_idx=padding_idx, device=device, process_group=process_group, + weight=module.weight, *args, **kwargs) - with torch.no_grad(): - # shard and slice the weight along the vocabulary(num_embeddings) dimension - # the shape of the weight is (num_embeddings, embedding_dim) - shard_weight = shard_rowwise(module.weight.data, process_group) - vocab_embedding_1d.weight.data.copy_(shard_weight) - return vocab_embedding_1d def reset_parameters(self, weight_initializer) -> None: diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 95a44e09e16a..e1033b0b39b1 
100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -235,6 +235,14 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): return param +def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None: + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + param.data = dtensor + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + def compute_global_numel(dtensor: torch.Tensor) -> int: """ Compute the global number of elements in the distributed tensor. diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 99e494359af7..d62dba7ea92a 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -15,11 +15,13 @@ def check_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + embedding = nn.Embedding(32, 128).cuda() with ctx: - embedding = nn.Embedding(32, 128).cuda() - embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) + embedding_copy = nn.Embedding(32, 128).cuda() + embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None) assert embedding_1d.weight.shape == torch.Size([32, 64]) + assert embedding_1d.weight is embedding_copy.weight # ensure state dict is reversibly loadable embedding.load_state_dict(embedding_1d.state_dict()) From 03242bfd6821feaa8bc6e4bb0369bf02e84cb0d8 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 11:02:30 +0800 Subject: [PATCH 02/16] [shardformer] linear support inplace sharding --- colossalai/shardformer/layer/linear.py | 107 +++++++++++------- .../test_layer/test_linear_1d.py | 34 ++++-- 2 files changed, 88 insertions(+), 53 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 383d9b3f533a..461a18d87561 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- import math -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -15,7 +15,7 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_existing_param from ._operation import ( gather_forward_split_backward, @@ -65,6 +65,8 @@ def __init__(self, process_group: ProcessGroup = None, gather_output: bool = False, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -80,26 +82,40 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # Parameters. 
- factory_kwargs = {'device': device, 'dtype': dtype} + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) - sharded_weight = shard_rowwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + + # Parameters. + if weight is None: + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if bias: - bias = torch.empty(self.out_features, **factory_kwargs) - sharded_bias = shard_colwise(bias, self.process_group) - self.bias = sharded_tensor_to_param(sharded_bias) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_ = bias_.to(device=device, dtype=dtype) + self.bias = bias_ + sharded_bias = shard_colwise(self.bias.data, self.process_group) + sharded_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -125,17 +141,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - with torch.no_grad(): - # the weight to the linear layer is a transpose - # thus shard on row is equal to shard on column - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - if bias: - sharded_bias = shard_colwise(module.bias.data, process_group) - linear_1d.bias.copy_(sharded_bias) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -198,6 +208,8 @@ def __init__(self, process_group: ProcessGroup = None, parallel_input: bool = True, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1): @@ -216,25 +228,40 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 
'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) - sharded_weight = shard_colwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) + if weight is None: + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + sharded_weight = shard_colwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if self.stream_chunk_num > 1: # TODO() work for inference only self.chunk_weight() + if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): self.reset_parameters(weight_initializer, bias_initializer) @@ -262,19 +289,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_colwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - - if bias: - linear_1d.bias.copy_(module.bias.data) - return linear_1d def chunk_weight(self): diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index da3cd85ec407..aa75879e0313 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -15,14 +15,16 @@ @parameterize('lazy_init', [False, True]) def check_linear_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - + linear = nn.Linear(32, 128).cuda() with ctx: - linear = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + linear_copy = nn.Linear(32, 128).cuda() + linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) assert is_distributed_tensor(linear_col.bias) + assert linear_copy.weight is linear_col.weight + assert linear_copy.bias is linear_col.bias # ensure the shape is correct assert linear_col.weight.shape == torch.Size([64, 32]) @@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool): def check_linear_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear = nn.Linear(32, 128).cuda() with ctx: - linear = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_copy = nn.Linear(32, 128).cuda() + linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) 
assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias + + linear.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear.state_dict()) # check computation correctness x = torch.rand(4, 32).cuda() @@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool): def check_linear_col_plus_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() + with ctx: - linear_1 = nn.Linear(32, 128).cuda() - linear_2 = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) + linear_1_copy = nn.Linear(32, 128).cuda() + linear_2_copy = nn.Linear(128, 32).cuda() + linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + + linear_1.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear_1.state_dict()) + linear_2.load_state_dict(linear_row.state_dict()) + linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness x = torch.rand(4, 32).cuda() From aaeb60440060df3795f583a7877cf6f9cb8c57d0 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 11:07:09 +0800 Subject: [PATCH 03/16] [shardformer] layernorm support inplace sharding --- colossalai/shardformer/layer/normalization.py | 10 +++------- tests/test_shardformer/test_layer/test_layernorm.py | 7 +++++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 9bb7738c0f0a..0aea295664a7 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -60,10 +60,8 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) - with torch.no_grad(): - # copy weight and bias - layernorm.weight.copy_(module.weight) - layernorm.bias.copy_(module.bias) + layernorm.weight = module.weight + layernorm.bias = module.bias return layernorm @@ -101,8 +99,6 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) - with torch.no_grad(): - # copy weight and bias - rmsnorm.weight.copy_(module.weight) + rmsnorm.weight = module.weight return rmsnorm diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index 2cb6928edf83..f9c21b82a282 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -14,11 +14,14 @@ def check_layernorm(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + norm = nn.LayerNorm(128, 0.00001).cuda() with ctx: - norm = nn.LayerNorm(128, 0.00001).cuda() - norm1d = FusedLayerNorm.from_native_module(norm, process_group=None) + norm_copy = nn.LayerNorm(128, 0.00001).cuda() + norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None) assert norm1d.weight.shape == 
torch.Size([128]) + assert norm_copy.weight is norm1d.weight + assert norm_copy.bias is norm1d.bias # ensure state dict is reversibly loadable norm.load_state_dict(norm1d.state_dict()) From 6d2633437966986dc250584f7f34dad5874af8c1 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 11:38:51 +0800 Subject: [PATCH 04/16] [shardformer] qkv support inplace sharding --- colossalai/shardformer/layer/embedding.py | 7 +- .../shardformer/layer/qkv_fused_linear.py | 123 ++++++++++-------- colossalai/tensor/d_tensor/api.py | 12 ++ .../test_layer/test_qkv_fused_linear_1d.py | 15 ++- 4 files changed, 90 insertions(+), 67 deletions(-) diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index a2e5434c9d77..e7ca5986f6ad 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -13,12 +13,7 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import ( - shard_colwise, - shard_rowwise, - sharded_tensor_to_existing_param, - sharded_tensor_to_param, -) +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_existing_param from ._operation import gather_forward_split_backward, reduce_forward from .parallel_module import ParallelModule diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index c94d93069e93..50e00588aedf 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -2,12 +2,11 @@ # -*- encoding: utf-8 -*- import math -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn -import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter @@ -16,10 +15,10 @@ from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import ( - customized_distributed_tensor_to_param, + customized_distributed_tensor_to_existing_param, distribute_tensor_with_customization, shard_rowwise, - sharded_tensor_to_param, + sharded_tensor_to_existing_param, ) from ._operation import ( @@ -173,6 +172,8 @@ def __init__(self, gather_output: bool = False, skip_bias_add: bool = False, n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -190,10 +191,24 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + if weight is None: + # Initialize weight. 
+ factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight def shard_fn(tensor): return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) @@ -202,24 +217,24 @@ def gather_fn(tensor): return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) with torch.no_grad(): - sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) - self.weight = customized_distributed_tensor_to_param(sharded_weight) + sharded_weight = distribute_tensor_with_customization(self.weight, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) if bias: - bias = torch.empty(self.out_features, **factory_kwargs) - + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_ = bias_.to(device=device, dtype=dtype) + self.bias = bias_ with torch.no_grad(): - sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) - self.bias = customized_distributed_tensor_to_param(sharded_bias) + sharded_bias = distribute_tensor_with_customization(self.bias, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, @@ -250,24 +265,11 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=True) - linear_1d.weight.data.copy_(sharded_weight.data) - - if bias: - sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, - n_fused=n_fused, - process_group=process_group, - is_transposed=True) - linear_1d.bias.data.copy_(sharded_bias.data) - return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -333,6 +335,8 @@ def __init__(self, process_group: ProcessGroup = None, parallel_input: bool = True, skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1): @@ -351,30 +355,45 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + # Divide the weight matrix along the last dimension. 
self.input_size_per_partition = divide(in_features, self.num_partitions) + # sanity check + if weight is not None: + assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + else: + assert bias_ is None, 'bias_ must be None if weight is None' + # Parameters. - # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) - sharded_weight = shard_rowwise(weight, self.process_group) - self.weight = sharded_tensor_to_param(sharded_weight) + if weight is None: + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + self.weight = weight + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if self.stream_chunk_num > 1: # TODO() work for inference only self.chunk_weight() if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + if bias_ is None: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + bias_.data = bias_.data.to(device=device, dtype=dtype) + self.bias = bias_ else: self.bias = None - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - # init weights - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -400,19 +419,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis bias=bias, device=device, process_group=process_group, + weight=module.weight, + bias_=module.bias, *args, **kwargs) - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight.data) - - if bias: - linear_1d.bias.copy_(module.bias.data) - return linear_1d def chunk_weight(self): diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index e1033b0b39b1..32182faf6981 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -440,3 +440,15 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: param.gather_fn = dtensor.gather_fn _hijack_detach_and_clone_for_customized_distributed_tensor(param) return param + + +def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter): + """ + Convert the given customized distributed tensor to an existing parameter. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' 
+ + param.data = dtensor.data + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 186b1e8212cc..a410a3c1425c 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -56,10 +56,10 @@ def rearrange(tensor: torch.Tensor, dim: int): @parameterize('lazy_init', [False, True]) def check_linear_conv_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - + linear = Conv1D(192, 48).cuda() with ctx: - linear = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + linear_copy = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, n_fused=3) @@ -91,14 +91,19 @@ def check_linear_conv_1d_col(lazy_init: bool): def check_linear_conv_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + linear = Conv1D(192, 48).cuda() with ctx: - linear = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_copy = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.bias.shape == torch.Size([192]) + # ensure weights are reversibly loadable + linear_row.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_row.state_dict()) + # check computation correctness x = torch.rand(4, 48).cuda() out = linear(x) From a3f22c0f28fbbd272c621508920d3215117dee48 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 11:59:48 +0800 Subject: [PATCH 05/16] [test] update shardformer layer test --- .../test_layer/test_qkv_fused_linear_1d.py | 4 ++++ .../test_layer/test_vocab_parallel_embedding_1d.py | 9 +++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index a410a3c1425c..b45cd172c3ca 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -68,6 +68,8 @@ def check_linear_conv_1d_col(lazy_init: bool): assert linear.bias.shape == torch.Size([192]) assert linear_conv_col.weight.shape == torch.Size([48, 96]) assert linear_conv_col.bias.shape == torch.Size([96]) + assert linear_copy.weight is linear_conv_col.weight + assert linear_copy.bias is linear_conv_col.bias # ensure weights are reversibly loadable linear_conv_col.load_state_dict(linear.state_dict()) @@ -99,6 +101,8 @@ def check_linear_conv_1d_row(lazy_init: bool): assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.bias.shape == torch.Size([192]) + assert linear_copy.weight is linear_row.weight + assert linear_copy.bias is linear_row.bias # ensure weights are reversibly loadable linear_row.load_state_dict(linear.state_dict()) diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py 
b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index bf5803496f03..6d2f087302d9 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -7,8 +7,7 @@ import colossalai from colossalai.lazy import LazyInitContext -from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D -from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.shardformer.layer import VocabParallelEmbedding1D from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -16,13 +15,15 @@ def check_vocab_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() + embedding = nn.Embedding(128, 32).to('cuda') with ctx: - embedding = nn.Embedding(128, 32).to('cuda') - dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None) + embedding_copy = nn.Embedding(128, 32).to('cuda') + dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) assert dist_embedding_1d.num_embeddings == 64 assert dist_embedding_1d.embedding_dim == 32 + assert embedding_copy.weight is dist_embedding_1d.weight # ensure state dict is reversibly loadable embedding.load_state_dict(dist_embedding_1d.state_dict()) From b538bd894f49e9e78f771ad85521a42a9858df71 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 14:49:04 +0800 Subject: [PATCH 06/16] [shardformer] fix shared param sharding --- colossalai/shardformer/layer/embedding.py | 17 +++++++---- colossalai/shardformer/layer/linear.py | 29 ++++++++++++------- .../shardformer/layer/qkv_fused_linear.py | 22 ++++++++------ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index e7ca5986f6ad..09b22abb17cc 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -13,7 +13,12 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_existing_param +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) from ._operation import gather_forward_split_backward, reduce_forward from .parallel_module import ParallelModule @@ -86,8 +91,9 @@ def __init__(self, else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight - sharded_weight = shard_colwise(self.weight.data, process_group) - sharded_tensor_to_existing_param(sharded_weight, self.weight) + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if weight is None: with self.randomizer.fork_rng(enable_cpu=True): @@ -220,8 +226,9 @@ def __init__(self, else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight - sharded_weight = shard_rowwise(self.weight.data, process_group) - sharded_tensor_to_existing_param(sharded_weight, self.weight) + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, process_group) + 
sharded_tensor_to_existing_param(sharded_weight, self.weight) if weight is None: self.reset_parameters(weight_initializer) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 461a18d87561..bb36854bd772 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -15,7 +15,12 @@ from colossalai.lazy import LazyInitContext from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_existing_param +from colossalai.tensor.d_tensor.api import ( + is_distributed_tensor, + shard_colwise, + shard_rowwise, + sharded_tensor_to_existing_param, +) from ._operation import ( gather_forward_split_backward, @@ -99,17 +104,19 @@ def __init__(self, else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight - sharded_weight = shard_rowwise(self.weight.data, self.process_group) - sharded_tensor_to_existing_param(sharded_weight, self.weight) + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if bias: if bias_ is None: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) else: - bias_ = bias_.to(device=device, dtype=dtype) + bias_.data = bias_.data.to(device=device, dtype=dtype) self.bias = bias_ - sharded_bias = shard_colwise(self.bias.data, self.process_group) - sharded_tensor_to_existing_param(sharded_bias, self.bias) + if not is_distributed_tensor(self.bias): + sharded_bias = shard_colwise(self.bias.data, self.process_group) + sharded_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None @@ -246,8 +253,9 @@ def __init__(self, else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight - sharded_weight = shard_colwise(self.weight.data, self.process_group) - sharded_tensor_to_existing_param(sharded_weight, self.weight) + if not is_distributed_tensor(self.weight): + sharded_weight = shard_colwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if self.stream_chunk_num > 1: # TODO() work for inference only @@ -262,8 +270,9 @@ def __init__(self, else: self.bias = None - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + if weight is None: + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 50e00588aedf..d6d7d27292e1 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -17,6 +17,7 @@ from colossalai.tensor.d_tensor.api import ( customized_distributed_tensor_to_existing_param, distribute_tensor_with_customization, + is_distributed_tensor, shard_rowwise, sharded_tensor_to_existing_param, ) @@ -216,19 +217,21 @@ def shard_fn(tensor): def gather_fn(tensor): return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) - with torch.no_grad(): - sharded_weight = distribute_tensor_with_customization(self.weight, shard_fn, gather_fn) - customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) + if not 
is_distributed_tensor(self.weight): + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(self.weight, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) if bias: if bias_ is None: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) else: - bias_ = bias_.to(device=device, dtype=dtype) + bias_.data = bias_.data.to(device=device, dtype=dtype) self.bias = bias_ - with torch.no_grad(): - sharded_bias = distribute_tensor_with_customization(self.bias, shard_fn, gather_fn) - customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) + if not is_distributed_tensor(self.bias): + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(self.bias, shard_fn, gather_fn) + customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None @@ -376,8 +379,9 @@ def __init__(self, else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight - sharded_weight = shard_rowwise(self.weight.data, self.process_group) - sharded_tensor_to_existing_param(sharded_weight, self.weight) + if not is_distributed_tensor(self.weight): + sharded_weight = shard_rowwise(self.weight.data, self.process_group) + sharded_tensor_to_existing_param(sharded_weight, self.weight) if self.stream_chunk_num > 1: # TODO() work for inference only From 4f0818557ce00b0a84a4b53e784eb29b7bccc234 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 15:21:11 +0800 Subject: [PATCH 07/16] [shardformer] fix bert policy --- colossalai/shardformer/layer/__init__.py | 3 +- colossalai/shardformer/policies/bert.py | 127 +++++++++--------- tests/test_shardformer/test_model/_utils.py | 14 ++ .../test_model/test_shard_bert.py | 3 +- 4 files changed, 81 insertions(+), 66 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7fad4948dfd0..7cdcfc31811f 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,10 +3,11 @@ from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm +from .parallel_module import ParallelModule from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm' + 'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule' ] diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 1af26f50484c..89800cafeb6e 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -80,44 +80,44 @@ def module_policy(self): "crossattention.self.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.self.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - 
SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( @@ -143,8 +143,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ) ], - policy=policy, - target_key=BertLayer) + policy=policy, + target_key=BertLayer) # handle embedding layer self.append_or_create_submodule_replacement( description=[SubModuleReplacementDescription( @@ -163,8 +163,8 @@ def add_lm_head_policy(self, base_policy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - policy=base_policy, - target_key=BertLMPredictionHead) + policy=base_policy, + target_key=BertLMPredictionHead) # optimize with fused normalization if self.shard_config.enable_fused_normalization: @@ -173,8 +173,19 @@ def add_lm_head_policy(self, base_policy): suffix="transform.LayerNorm", target_module=col_nn.FusedLayerNorm, ), - policy=base_policy, - target_key=BertLMPredictionHead) + policy=base_policy, + target_key=BertLMPredictionHead) + return base_policy + + def add_lm_prediction_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + method_replacement = { + '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict, + '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict, + } + self.append_or_create_method_replacement(description=method_replacement, + policy=base_policy, + target_key=BertLMPredictionHead) return base_policy def postprocess(self): @@ -240,6 +251,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForPreTraining self.set_pipeline_forward(model_cls=BertForPreTraining, 
new_forward=bert_for_pretraining_forward, policy=policy) return policy @@ -266,21 +278,13 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: model = self.model if self.pipeline_stage_manager: if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: model.bert.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): @@ -291,6 +295,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertLMHeadModel self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy) return policy @@ -316,21 +321,13 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert if self.pipeline_stage_manager: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: bert_model.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): @@ -341,6 +338,7 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) + mpolicy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy) return policy @@ -366,7 +364,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: bert_model = self.model.bert if self.pipeline_stage_manager: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): - #tie weights + # tie weights return [{ 0: bert_model.embeddings.word_embeddings.weight, self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight @@ -1032,6 +1030,7 @@ def bert_for_masked_lm_forward( stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, ): + # -> Union[Tuple[torch.Tensor], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., @@ -1109,7 +1108,7 @@ def bert_for_next_sentence_prediction_forward( stage_index: Optional[List[int]] = None, **kwargs, ): - #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 825d6df6bb5e..48436ceaeb29 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,6 +1,9 @@ import copy from contextlib import nullcontext +import torch +from torch.nn import Module + from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -61,3 +64,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, shard_output = output_transform_fn(shard_output) shard_loss = loss_fn(shard_output) return org_output, org_loss, shard_output, shard_loss + + +def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''): + org_sd = org_model.state_dict() + shard_sd = sharded_model.state_dict() + for k, v in org_sd.items(): + assert k in shard_sd, f'{name} {k} not in sharded model' + shard_v = shard_sd[k] + assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}' + assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}' + assert torch.equal(v, shard_v), f'{name} {k} value mismatch' diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 7f179acd7356..ea0f122644dc 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -12,7 +12,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -75,6 +75,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 718ac414af374eecaf5cb7e04e9e778ce1c4adaf Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 15:25:17 +0800 Subject: [PATCH 08/16] [shardformer] fix bloom policy --- colossalai/shardformer/policies/bert.py | 4 +- colossalai/shardformer/policies/bloom.py | 79 ++++++++----------- .../test_model/test_shard_bloom.py | 3 +- 3 files changed, 36 insertions(+), 50 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 89800cafeb6e..b4e71eff7b03 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -3,7 +3,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import 
torch -import torch.nn as nn from torch import Tensor from torch.nn import CrossEntropyLoss, Module from transformers.modeling_outputs import ( @@ -28,12 +27,11 @@ BertLMHeadModel, BertModel, ) -from transformers.utils import ModelOutput, logging +from transformers.utils import logging import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription logger = logging.get_logger(__name__) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 8afaadefb696..87555dfb3f61 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -3,9 +3,7 @@ from types import MethodType from typing import Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch -import torch.nn as nn from torch import Tensor from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( @@ -27,7 +25,6 @@ import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from ..modeling.bloom import build_bloom_alibi_tensor_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -63,28 +60,28 @@ def module_policy(self): "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=col_nn.Linear1D_Row, - ), - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ), + ]) policy[BloomModel] = ModulePolicyDescription( attribute_replacement={ @@ -113,8 +110,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ) ], - policy=policy, - target_key=BloomModel) + policy=policy, + target_key=BloomModel) # handle bloom block self.append_or_create_submodule_replacement(description=[ @@ -127,8 +124,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ) ], - policy=policy, - target_key=BloomBlock) + policy=policy, + target_key=BloomBlock) return policy @@ -200,8 +197,8 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: 
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForCausalLM) + policy=policy, + target_key=BloomForCausalLM) self.set_pipeline_forward(model_cls=BloomForCausalLM, new_forward=bloom_for_causal_lm_forward, policy=policy) return policy @@ -233,16 +230,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} - - for k, v in binding_map.items(): - param = getattr_(self.model, k) - # tie weights - setattr_(self.model, v, param) - return self.model - class BloomForSequenceClassificationPolicy(BloomPolicy): @@ -254,8 +241,8 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForSequenceClassification) + policy=policy, + target_key=BloomForSequenceClassification) self.set_pipeline_forward(model_cls=BloomForSequenceClassification, new_forward=bloom_for_sequence_classification_forward, policy=policy) @@ -299,8 +286,8 @@ def module_policy(self): target_module=col_nn.DropoutForReplicatedInput, ), ], - policy=policy, - target_key=BloomForTokenClassification) + policy=policy, + target_key=BloomForTokenClassification) self.set_pipeline_forward(model_cls=BloomForTokenClassification, new_forward=bloom_for_token_classification_forward, @@ -692,7 +679,7 @@ def bloom_for_sequence_classification_forward( all_cross_attentions = None if stage_manager.is_last_stage(): batch_size = hidden_states.shape[0] - #update batch size + # update batch size hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index e18168292df5..fe4686aeb979 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -12,7 +12,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -75,6 +75,7 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 9718fa488f47e01be1f65edabdc95a1639d7ecea Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 15:34:20 +0800 Subject: [PATCH 09/16] [shardformer] fix llama policy --- colossalai/shardformer/policies/llama.py | 9 ++------- tests/test_shardformer/test_model/test_shard_llama.py | 3 ++- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py 
b/colossalai/shardformer/policies/llama.py index b3757452c314..c7cd8182a4ca 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,7 +1,5 @@ -import math from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import torch import torch.nn as nn @@ -9,14 +7,11 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel -from transformers.utils import ModelOutput, logging +from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 4d63a43489a3..aaeef13ef873 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -14,7 +14,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -78,6 +78,7 @@ def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_la for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 93e780c6bee452ee99103f087a358ef1032069fa Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 15:36:58 +0800 Subject: [PATCH 10/16] [shardformer] fix opt policy --- colossalai/shardformer/policies/opt.py | 14 -------------- .../test_shardformer/test_model/test_shard_opt.py | 3 ++- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 1435805d2846..bbcc90e00157 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,6 +1,5 @@ from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -116,19 +115,6 @@ def module_policy(self): target_key=OPTForCausalLM) return policy - def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = { - 'model.decoder.embed_tokens': 'lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - - return self.model - class OPTForSequenceClassificationPolicy(OPTPolicy): diff --git a/tests/test_shardformer/test_model/test_shard_opt.py 
b/tests/test_shardformer/test_model/test_shard_opt.py index c008596fe2b6..297affceb68a 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -15,7 +15,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -77,6 +77,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 55b5da148d9098a111965b716a727b65ea8f3359 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 14 Jul 2023 15:59:16 +0800 Subject: [PATCH 11/16] [shardformer] fix t5 policy --- colossalai/shardformer/policies/t5.py | 32 +------------------ .../test_model/test_shard_t5.py | 3 +- 2 files changed, 3 insertions(+), 32 deletions(-) diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 37864885b4cc..6b8f404f1769 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -8,7 +8,6 @@ ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] @@ -53,7 +52,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="embed_tokens", - target_module=Embedding1D, + target_module=VocabParallelEmbedding1D, ) ]) policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ @@ -165,12 +164,6 @@ def module_policy(self): return policy def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] - - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) return self.model @@ -211,18 +204,6 @@ def module_policy(self): target_key=T5ForConditionalGeneration) return policy - def postprocess(self): - super().postprocess() - if self.shard_config.enable_tensor_parallelism: - binding_map = {"shared": "lm_head"} - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - - return self.model - class T5EncoderPolicy(T5BasePolicy): @@ -239,14 +220,3 @@ def module_policy(self): policy=base_policy, target_key=T5EncoderModel) return base_policy - - def postprocess(self): - if self.shard_config.enable_tensor_parallelism: - binding_map = [ - ["shared", "encoder.embed_tokens"], - ] - - for k, v in binding_map: - mod = getattr_(self.model, k) - setattr_(self.model, v, mod) - return self.model diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index ccd7d3787d3d..96dfdeb73827 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -14,7 +14,7 @@ spawn, ) from 
tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -88,6 +88,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 859f0c2eeb278ece755cf8fb8c65c5fa50816b5e Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 17 Jul 2023 17:34:39 +0800 Subject: [PATCH 12/16] [shardformer] fix fused qkv linear --- .../shardformer/layer/qkv_fused_linear.py | 15 +-- colossalai/shardformer/policies/gpt2.py | 99 ++++++++----------- tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_shard_gpt2.py | 3 +- 4 files changed, 50 insertions(+), 68 deletions(-) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index d6d7d27292e1..bcefcf058ce0 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -17,6 +17,7 @@ from colossalai.tensor.d_tensor.api import ( customized_distributed_tensor_to_existing_param, distribute_tensor_with_customization, + is_customized_distributed_tensor, is_distributed_tensor, shard_rowwise, sharded_tensor_to_existing_param, @@ -215,11 +216,11 @@ def shard_fn(tensor): return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) def gather_fn(tensor): - return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) + return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) - if not is_distributed_tensor(self.weight): + if not is_customized_distributed_tensor(self.weight): with torch.no_grad(): - sharded_weight = distribute_tensor_with_customization(self.weight, shard_fn, gather_fn) + sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn) customized_distributed_tensor_to_existing_param(sharded_weight, self.weight) if bias: @@ -228,9 +229,9 @@ def gather_fn(tensor): else: bias_.data = bias_.data.to(device=device, dtype=dtype) self.bias = bias_ - if not is_distributed_tensor(self.bias): + if not is_customized_distributed_tensor(self.bias): with torch.no_grad(): - sharded_bias = distribute_tensor_with_customization(self.bias, shard_fn, gather_fn) + sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn) customized_distributed_tensor_to_existing_param(sharded_bias, self.bias) else: self.bias = None @@ -240,8 +241,8 @@ def gather_fn(tensor): self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, - *args, **kwargs) -> ParallelModule: + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. 
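Note on the fused QKV sharding above: GPT-2's `Conv1D` stores the query, key and value projections concatenated along the output dimension, so a naive column split would hand each rank a contiguous block of only one projection. The snippet below is a minimal, self-contained sketch (not the library implementation) of the gpt2-style split that the `shard_fn`/`gather_fn` pair relies on; the function name `split_fused_qkv`, the explicit `rank`/`world_size` arguments and the toy sizes are illustrative assumptions.

    import torch

    def split_fused_qkv(weight: torch.Tensor, n_fused: int, world_size: int, rank: int) -> torch.Tensor:
        """Take one rank's column shard of a fused (in_features, n_fused * proj_dim) weight.

        The last dimension is viewed as n_fused groups (q, k, v for GPT-2), each group
        is sliced per rank, and the per-rank slices are re-concatenated, so every rank
        keeps a piece of each projection instead of a block of just one of them.
        """
        in_features, out_features = weight.shape
        group = out_features // n_fused      # size of one projection (q, k or v)
        shard = group // world_size          # columns of that projection per rank
        pieces = [
            weight[:, i * group + rank * shard:i * group + (rank + 1) * shard]
            for i in range(n_fused)
        ]
        return torch.cat(pieces, dim=-1)

    # toy usage: in_features 8, 3-way fusion (q, k, v), 2 tensor-parallel ranks
    w = torch.arange(8 * 24, dtype=torch.float32).reshape(8, 24)
    shard_rank0 = split_fused_qkv(w, n_fused=3, world_size=2, rank=0)
    assert shard_rank0.shape == (8, 12)

Gathering is the inverse: each projection's per-rank slices are reassembled before the projections are concatenated again, which is why `gather_fn` must use the same `n_fused` as `shard_fn` rather than a hard-coded 3, as this patch changes.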
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5d6f47636587..0dfba71e4e5c 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -10,7 +10,6 @@ import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -58,42 +57,42 @@ def module_policy(self): "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -101,8 +100,8 @@ def module_policy(self): suffix="ln_f", target_module=col_nn.FusedLayerNorm, ), - policy=policy, - target_key=GPT2Model) + policy=policy, + target_key=GPT2Model) self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( @@ -117,8 +116,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True) ], - policy=policy, - target_key=GPT2Block) + policy=policy, + target_key=GPT2Block) return policy def postprocess(self): @@ -229,15 +228,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: else: return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism \ - and self.pipeline_stage_manager is None: - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # GPT2DoubleHeadsModel class 
GPT2DoubleHeadsModelPolicy(GPT2Policy): @@ -288,15 +278,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: else: return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism \ - and self.pipeline_stage_manager is None: - binding_map = {"transformer.wte.weight": "lm_head.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # GPT2ForTokenClassification class GPT2ForTokenClassificationPolicy(GPT2Policy): diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 48436ceaeb29..2320c725d444 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -5,7 +5,6 @@ from torch.nn import Module from colossalai.lazy import LazyInitContext -from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 552c6e2f4d53..99451b403eb7 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -12,7 +12,7 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, run_forward +from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): @@ -77,6 +77,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) + check_state_dict(org_model, sharded_model, name=name) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() From 2ae2118545f5276dfa09ebeaf9defe578edb0db4 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 17 Jul 2023 18:21:59 +0800 Subject: [PATCH 13/16] [shardformer] fix bugs --- colossalai/shardformer/policies/bert.py | 101 +++++++++++------------ colossalai/shardformer/policies/bloom.py | 68 +++++++-------- colossalai/shardformer/shard/sharder.py | 4 +- 3 files changed, 83 insertions(+), 90 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index b4e71eff7b03..0a1a466210b2 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,12 +1,11 @@ from functools import partial -from types import MethodType -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional import torch +import torch.nn as nn from torch import Tensor from torch.nn import CrossEntropyLoss, Module from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MultipleChoiceModelOutput, @@ -78,44 +77,44 @@ def module_policy(self): "crossattention.self.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.self.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - 
suffix="attention.self.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( @@ -141,8 +140,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ) ], - policy=policy, - target_key=BertLayer) + policy=policy, + target_key=BertLayer) # handle embedding layer self.append_or_create_submodule_replacement( description=[SubModuleReplacementDescription( @@ -161,8 +160,8 @@ def add_lm_head_policy(self, base_policy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - policy=base_policy, - target_key=BertLMPredictionHead) + policy=base_policy, + target_key=BertLMPredictionHead) # optimize with fused normalization if self.shard_config.enable_fused_normalization: @@ -171,8 +170,8 @@ def add_lm_head_policy(self, base_policy): suffix="transform.LayerNorm", target_module=col_nn.FusedLayerNorm, ), - policy=base_policy, - target_key=BertLMPredictionHead) + policy=base_policy, + target_key=BertLMPredictionHead) return base_policy def add_lm_prediction_policy(self, base_policy): @@ -369,14 +368,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: }] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) - return self.model - # 
BertForSequenceClassification class BertForSequenceClassificationPolicy(BertPolicy): diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 87555dfb3f61..b0e45452964e 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,9 +1,9 @@ import warnings from functools import partial -from types import MethodType from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn from torch import Tensor from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( @@ -60,28 +60,28 @@ def module_policy(self): "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=col_nn.Linear1D_Row, - ), - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ), + ]) policy[BloomModel] = ModulePolicyDescription( attribute_replacement={ @@ -110,8 +110,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ) ], - policy=policy, - target_key=BloomModel) + policy=policy, + target_key=BloomModel) # handle bloom block self.append_or_create_submodule_replacement(description=[ @@ -124,8 +124,8 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ) ], - policy=policy, - target_key=BloomBlock) + policy=policy, + target_key=BloomBlock) return policy @@ -197,8 +197,8 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForCausalLM) + policy=policy, + target_key=BloomForCausalLM) self.set_pipeline_forward(model_cls=BloomForCausalLM, new_forward=bloom_for_causal_lm_forward, policy=policy) return policy @@ -226,7 +226,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # tie weights return [{ 0: bloom_model.transformer.word_embeddings.weight, - self.stage_manager.num_stages - 1: bloom_model.lm_head.weight + self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight }] return [] @@ -241,8 +241,8 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: 
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForSequenceClassification) + policy=policy, + target_key=BloomForSequenceClassification) self.set_pipeline_forward(model_cls=BloomForSequenceClassification, new_forward=bloom_for_sequence_classification_forward, policy=policy) @@ -286,8 +286,8 @@ def module_policy(self): target_module=col_nn.DropoutForReplicatedInput, ), ], - policy=policy, - target_key=BloomForTokenClassification) + policy=policy, + target_key=BloomForTokenClassification) self.set_pipeline_forward(model_cls=BloomForTokenClassification, new_forward=bloom_for_token_classification_forward, diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 5e0b572e259c..b32c285bdaab 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -37,11 +37,13 @@ def shard(self) -> List[Dict[int, Tensor]]: self.policy.set_model(self.model) self.policy.set_shard_config(self.shard_config) self._preprocess() + # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None) + shared_params = self.policy.get_shared_params() self._release_unheld_layers() self._replace_module() self._materialize() self._postprocess() - return self.policy.get_shared_params() + return shared_params def _preprocess(self) -> None: self.model = self.policy.preprocess() From e6c7dc62cbbaf3a43543c5444a328f1679c3357a Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 18 Jul 2023 10:38:47 +0800 Subject: [PATCH 14/16] force sync From 13e2d23db8a7578f15bee5492cb8b5f63b68cd11 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 18 Jul 2023 14:30:18 +0800 Subject: [PATCH 15/16] [test] fix bugs --- requirements/requirements-test.txt | 1 + tests/kit/model_zoo/torchrec/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 50121a9283f2..5de4648070f1 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -16,3 +16,4 @@ triton==2.0.0.dev20221202 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece +datasets diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * From 049c58114e5a3cd052cdadd7677fbd00629b9646 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 19 Jul 2023 14:20:48 +0800 Subject: [PATCH 16/16] [test] fix transformer version --- requirements/requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 5de4648070f1..6f8a72e3962f 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -4,7 +4,7 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers +transformers==4.30.2 timm titans torchaudio
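A closing note on the sharder change in patch 13 above: collecting shared parameters before unheld layers are released matters because released parameters are set to `None`, and two `None` attributes always compare as identical, so genuinely tied weights can no longer be told apart from unrelated ones. The toy module below is an illustrative sketch of that false positive, not ColossalAI code; the class `Toy` and its attribute names are made up for the example.

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)
            self.lm_head = nn.Linear(4, 10, bias=False)
            self.lm_head.weight = self.embed.weight   # genuinely tied weights

    model = Toy()
    # before any layer is released, identity correctly detects the tie
    assert model.embed.weight is model.lm_head.weight

    # simulate releasing "unheld" layers by dropping their parameters
    model.embed.weight = None
    model.lm_head.weight = None
    # identity still holds, but only because None is None: a false positive
    assert model.embed.weight is model.lm_head.weight

Recording the shared-parameter mapping before `_release_unheld_layers()` runs avoids relying on this degenerate identity check.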