From 2c2c3cd994a8b4bc78d3216f47280b6c8f2733d3 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 7 Mar 2024 10:35:27 +0800 Subject: [PATCH 01/12] fix --- applications/Colossal-LLaMA-2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 2e4bab75a085..d97da61e4dc8 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -56,6 +56,7 @@ def format_numel_str(numel: int) -> str: def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor = tensor.data tensor.div_(dist.get_world_size()) return tensor From f3f454ed1475aa4f9761f7b5d900ec33e7f99567 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 7 Mar 2024 11:57:09 +0800 Subject: [PATCH 02/12] padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix --- applications/Colossal-LLaMA-2/train.py | 1 - colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 ++++ colossalai/shardformer/modeling/gpt2.py | 9 +++++---- colossalai/shardformer/modeling/llama.py | 5 +++-- colossalai/shardformer/policies/gpt2.py | 7 +++++++ colossalai/shardformer/shard/shard_config.py | 2 ++ 6 files changed, 21 insertions(+), 7 deletions(-) diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index d97da61e4dc8..2e4bab75a085 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -56,7 +56,6 @@ def format_numel_str(numel: int) -> str: def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) - tensor = tensor.data tensor.div_(dist.get_world_size()) return tensor diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index da67e6b41fbf..ed339bf1c5ff 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -937,6 +937,7 @@ def __init__( enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, enable_sequence_overlap: bool = False, + parallel_output: bool = True, num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, initial_scale: float = 2**16, @@ -961,6 +962,7 @@ def __init__( pp_style: str = "1f1b", num_model_chunks: int = 1, enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 128, ) -> None: super().__init__() assert ( @@ -1033,6 +1035,8 @@ def __init__( enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, + parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 055e3096d794..275cba64466a 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -783,11 +783,12 @@ def forward( scale = scale * (1 / float(self.layer_idx + 1)) # use coloattention - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale - ) + if not hasattr(self, "attention"): + self.attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, 
scale=scale + ) - attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index e10a7ed7da0c..7815d5c6e074 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -481,8 +481,9 @@ def forward( flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention( + if not hasattr(self, "attention"): + self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) + attn_output = self.attention( query_states, key_states, value_states, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 022e6ff5b32c..247ae0d7c564 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -33,6 +33,13 @@ def preprocess(self): if vocab_size % world_size != 0: new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) + elif self.shard_config.pipeline_stage_manager is not None: + # padding vocab_size when using pipeline parallellism + new_vocab_size = vocab_size + multiple = self.shard_config.make_vocab_size_divisible_by + while (new_vocab_size % multiple) != 0: + new_vocab_size += 1 + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index b5c9e66e0b87..49ea382f5cfc 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -34,6 +34,8 @@ class ShardConfig: enable_all_optimization: bool = False enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False + parallel_output: bool = True + make_vocab_size_divisible_by: int = 128 extra_kwargs: Dict[str, Any] = field(default_factory=dict) # pipeline_parallel_size: int # data_parallel_size: int From f2af8fa621f214a44d88d6a4d90a59a89aa46093 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 7 Mar 2024 15:40:58 +0800 Subject: [PATCH 03/12] fix --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 2 ++ colossalai/shardformer/policies/gpt2.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ed339bf1c5ff..3722803d9a58 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -897,6 +897,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. 
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. microbatch_size (int, optional): Microbatch size when using pipeline parallelism. Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. @@ -923,6 +924,7 @@ class HybridParallelPlugin(PipelinePluginBase): pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. + make_vocab_size_divisible_by (bool, optional): make the vocabulary size is divisible by `make_vocab_size_divisible_by`, to select a faster CUDA kernel operator. Default to 128. """ def __init__( diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 247ae0d7c564..4a8d4edb7395 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -33,8 +33,8 @@ def preprocess(self): if vocab_size % world_size != 0: new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) - elif self.shard_config.pipeline_stage_manager is not None: - # padding vocab_size when using pipeline parallellism + else: + # Make vocab_size divisible by `make_vocab_size_divisible_by` to select a faster CUDA kernel operator. new_vocab_size = vocab_size multiple = self.shard_config.make_vocab_size_divisible_by while (new_vocab_size % multiple) != 0: From 349c818e241714521b523dac16c985ac2508d721 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 7 Mar 2024 16:37:26 +0800 Subject: [PATCH 04/12] fix fix fix --- colossalai/shardformer/modeling/gpt2.py | 7 +++++++ colossalai/shardformer/policies/gpt2.py | 7 +++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 06f07ee2b790..fa7e39ff4614 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -25,6 +25,7 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d +from ..layer._operation import _gather class GPT2PipelineForwards: @@ -337,6 +338,9 @@ def gpt2_lmhead_model_forward( else: loss = loss_fct(shift_logits, shift_labels) + if not shard_config.parallel_output: + lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group) + if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1084,6 +1088,9 @@ def forward( else: loss = loss_fct(shift_logits, shift_labels) + if not shard_config.parallel_output: + lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group) + if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 27d4f7acfc63..256f8035bf21 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -40,11 +40,10 @@ def preprocess(self): self.model.resize_token_embeddings(new_vocab_size) else: # Make vocab_size divisible by `make_vocab_size_divisible_by` to select a faster CUDA kernel operator. 
- new_vocab_size = vocab_size multiple = self.shard_config.make_vocab_size_divisible_by - while (new_vocab_size % multiple) != 0: - new_vocab_size += 1 - self.model.resize_token_embeddings(new_vocab_size) + if vocab_size % multiple != 0: + new_vocab_size = (vocab_size // multiple + 1) * multiple + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): From 21fd30413f49a02dfb6e11a054f710b6625e0910 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Sun, 10 Mar 2024 09:34:00 +0800 Subject: [PATCH 05/12] fix gather output --- colossalai/shardformer/modeling/gpt2.py | 6 +++--- colossalai/shardformer/modeling/llama.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index fa7e39ff4614..1e22d9094eae 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -25,7 +25,7 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d -from ..layer._operation import _gather +from ..layer._operation import gather_forward_split_backward class GPT2PipelineForwards: @@ -339,7 +339,7 @@ def gpt2_lmhead_model_forward( loss = loss_fct(shift_logits, shift_labels) if not shard_config.parallel_output: - lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group) + lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1089,7 +1089,7 @@ def forward( loss = loss_fct(shift_logits, shift_labels) if not shard_config.parallel_output: - lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group) + lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 8baffa0dcf90..eb8e9f748527 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -16,7 +16,7 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d -from ..layer._operation import _gather +from ..layer._operation import gather_forward_split_backward try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -290,7 +290,7 @@ def llama_for_causal_lm_forward( loss = loss_fct(shift_logits, shift_labels) if not shard_config.parallel_output: - logits = _gather(logits, -1, shard_config.tensor_parallel_process_group) + logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group) if not return_dict: output = (logits,) + outputs[1:] @@ -594,7 +594,7 @@ def forward( loss = loss_fct(shift_logits, shift_labels) if not shard_config.parallel_output: - logits = _gather(logits, -1, shard_config.tensor_parallel_process_group) + logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group) if not return_dict: output = (logits,) + outputs[1:] From 653aa060408b17b8b028bc90b583e798d10eb5c2 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 14 Mar 2024 07:27:49 +0800 Subject: [PATCH 06/12] fix --- colossalai/shardformer/policies/gpt2.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py 
b/colossalai/shardformer/policies/gpt2.py index 256f8035bf21..9d3ebed2540c 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -32,18 +32,14 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ + vocab_size = self.model.config.vocab_size + multiple = self.shard_config.make_vocab_size_divisible_by if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - else: - # Make vocab_size divisible by `make_vocab_size_divisible_by` to select a faster CUDA kernel operator. - multiple = self.shard_config.make_vocab_size_divisible_by - if vocab_size % multiple != 0: - new_vocab_size = (vocab_size // multiple + 1) * multiple - self.model.resize_token_embeddings(new_vocab_size) + multiple = multiple * world_size + if vocab_size % multiple != 0: + new_vocab_size = vocab_size + multiple - vocab_size % multiple + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): From a980e706dddec5bf258591620eaf8113e8a28ea3 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 14 Mar 2024 09:23:23 +0800 Subject: [PATCH 07/12] fix --- .../booster/plugin/hybrid_parallel_plugin.py | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a009f4c6da12..ed42694a166a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -189,7 +189,7 @@ def unwrap(self): return module -def get_param_info(optim: Optimizer): +def get_param_info(optim: Optimizer, model: torch.nn.Module): # Get a backup of necessary information of parameters for future use, which includes: # 1. A complete param_group, with params in the form of param_id # 2. 
A mapping from param address (obtained using id(param)) to integer param_id @@ -199,7 +199,14 @@ def get_param_info(optim: Optimizer): if optim is None: return {} - param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} + param_info = { + "param_groups": [], + "param2id": {}, + "id2param": {}, + "param2shape": {}, + "old_input_embedding_param_id": None, + "old_output_embedding_param_id": None, + } start_index = 0 for group in optim.param_groups: packed_group = {k: v for k, v in group.items() if k != "params"} @@ -215,6 +222,13 @@ def get_param_info(optim: Optimizer): param_info["param_groups"].append(packed_group) start_index += len(group["params"]) + input_embedding = model.get_input_embeddings() + if input_embedding is not None: + param_info["old_input_embedding_param_id"] = id(input_embedding.weight) + output_embedding = model.get_output_embeddings() + if output_embedding is not None: + param_info["old_output_embedding_param_id"] = id(output_embedding.weight) + return param_info @@ -1067,7 +1081,7 @@ def __init__( overlap_communication=overlap_communication, cpu_offload=cpu_offload, partition_grad=(self.zero_stage == 2), - forced_dtype=PRECISION_TORCH_TYPE[precision], + # forced_dtype=PRECISION_TORCH_TYPE[precision], ) self.max_norm = max_norm @@ -1076,6 +1090,32 @@ def __del__(self): """Destroy the process groups in ProcessGroupMesh""" self.pg_mesh.destroy_mesh_process_groups() + def set_resized_embedding_to_optimizer(self, model, optimizer, param_info): + old_input_embedding_param_id = param_info["old_input_embedding_param_id"] + if old_input_embedding_param_id is not None: + for param_group in optimizer.param_groups: + group_params = param_group["params"] + new_params = [] + for param in group_params: + if id(param) == old_input_embedding_param_id: + new_input_embeddings = model.module.get_input_embeddings() + new_params.append(new_input_embeddings.weight) + else: + new_params.append(param) + param_group["params"] = new_params + old_output_embedding_param_id = param_info["old_output_embedding_param_id"] + if old_output_embedding_param_id is not None: + for param_group in optimizer.param_groups: + group_params = param_group["params"] + new_params = [] + for param in group_params: + if id(param) == old_output_embedding_param_id: + new_output_embeddings = model.module.get_output_embeddings() + new_params.append(new_output_embeddings.weight) + else: + new_params.append(param) + param_group["params"] = new_params + @property def enable_pipeline_parallelism(self) -> bool: return self.pp_size > 1 @@ -1106,7 +1146,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - param_info = get_param_info(optimizer) + param_info = get_param_info(optimizer, model) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule( @@ -1119,6 +1159,8 @@ def configure( ddp_config=self.ddp_config, custom_policy=self.custom_policy, ) + + self.set_resized_embedding_to_optimizer(model, optimizer, param_info) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ["fp16", "bf16"]: From 64379ab5479c293e5e7f792b2d24aa571eb28b29 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Sun, 17 Mar 2024 21:31:45 +0800 Subject: [PATCH 08/12] fix fix resize embedding fix resize embedding --- 
.../booster/plugin/hybrid_parallel_plugin.py | 43 ++---------------- .../shardformer/policies/base_policy.py | 45 +++++++++++++++++++ colossalai/shardformer/policies/gpt2.py | 5 ++- colossalai/shardformer/policies/llama.py | 14 +++--- .../test_model/test_shard_gpt2.py | 2 +- 5 files changed, 59 insertions(+), 50 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ed42694a166a..639f80cf8da2 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -189,7 +189,7 @@ def unwrap(self): return module -def get_param_info(optim: Optimizer, model: torch.nn.Module): +def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A complete param_group, with params in the form of param_id # 2. A mapping from param address (obtained using id(param)) to integer param_id @@ -204,8 +204,6 @@ def get_param_info(optim: Optimizer, model: torch.nn.Module): "param2id": {}, "id2param": {}, "param2shape": {}, - "old_input_embedding_param_id": None, - "old_output_embedding_param_id": None, } start_index = 0 for group in optim.param_groups: @@ -222,13 +220,6 @@ def get_param_info(optim: Optimizer, model: torch.nn.Module): param_info["param_groups"].append(packed_group) start_index += len(group["params"]) - input_embedding = model.get_input_embeddings() - if input_embedding is not None: - param_info["old_input_embedding_param_id"] = id(input_embedding.weight) - output_embedding = model.get_output_embeddings() - if output_embedding is not None: - param_info["old_output_embedding_param_id"] = id(output_embedding.weight) - return param_info @@ -1081,7 +1072,7 @@ def __init__( overlap_communication=overlap_communication, cpu_offload=cpu_offload, partition_grad=(self.zero_stage == 2), - # forced_dtype=PRECISION_TORCH_TYPE[precision], + forced_dtype=PRECISION_TORCH_TYPE[precision], ) self.max_norm = max_norm @@ -1090,32 +1081,6 @@ def __del__(self): """Destroy the process groups in ProcessGroupMesh""" self.pg_mesh.destroy_mesh_process_groups() - def set_resized_embedding_to_optimizer(self, model, optimizer, param_info): - old_input_embedding_param_id = param_info["old_input_embedding_param_id"] - if old_input_embedding_param_id is not None: - for param_group in optimizer.param_groups: - group_params = param_group["params"] - new_params = [] - for param in group_params: - if id(param) == old_input_embedding_param_id: - new_input_embeddings = model.module.get_input_embeddings() - new_params.append(new_input_embeddings.weight) - else: - new_params.append(param) - param_group["params"] = new_params - old_output_embedding_param_id = param_info["old_output_embedding_param_id"] - if old_output_embedding_param_id is not None: - for param_group in optimizer.param_groups: - group_params = param_group["params"] - new_params = [] - for param in group_params: - if id(param) == old_output_embedding_param_id: - new_output_embeddings = model.module.get_output_embeddings() - new_params.append(new_output_embeddings.weight) - else: - new_params.append(param) - param_group["params"] = new_params - @property def enable_pipeline_parallelism(self) -> bool: return self.pp_size > 1 @@ -1146,7 +1111,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - param_info = get_param_info(optimizer, model) + 
param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule( @@ -1159,8 +1124,6 @@ def configure( ddp_config=self.ddp_config, custom_policy=self.custom_policy, ) - - self.set_resized_embedding_to_optimizer(model, optimizer, param_info) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ["fp16", "bf16"]: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 1d2b7a570681..bbd8c91af4b2 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -5,9 +5,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import torch import torch.nn as nn from torch import Tensor from torch.nn import Module +from colossalai.lazy.lazy_init import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager @@ -243,3 +245,46 @@ def get_stage_index( stage_indices.append([start_idx, end_idx]) return stage_indices[0] if num_model_chunks == 1 else stage_indices + + + def resize_token_embeddings(self, model, new_num_tokens): + input_embeddings = self.model.get_input_embeddings() + if input_embeddings is not None: + self._resize_token_embeddings(model, input_embeddings, new_num_tokens) + output_embedddings = self.model.get_output_embeddings() + if output_embedddings is not None: + self._resize_lm_head(model, output_embedddings, new_num_tokens) + + def _resize_token_embeddings(self, model, embedding, new_num_tokens): + LazyInitContext.materialize(embedding) + old_num_tokens = embedding.num_embeddings + input_embedding_dim = embedding.embedding_dim + old_weight_data = embedding.weight.data + embedding.num_embeddings = new_num_tokens + if embedding.padding_idx is not None and embedding.padding_idx > new_num_tokens: + embedding.padding_idx = embedding.padding_idx - (old_num_tokens-new_num_tokens) + factory_kwargs = {'device': embedding.weight.device, 'dtype': embedding.weight.dtype} + embedding.weight.data = torch.empty((new_num_tokens, input_embedding_dim), **factory_kwargs) + embedding.reset_parameters() + model._init_weights(embedding) + # Copy token embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + embedding.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :] + + def _resize_lm_head(self, model, lm_head, new_num_tokens): + LazyInitContext.materialize(lm_head) + old_num_tokens, lm_head_dim = (lm_head.weight.size()) + old_weight_data = lm_head.weight.data + old_bias_data = lm_head.bias.data if lm_head.bias is not None else None + lm_head.out_features = new_num_tokens + factory_kwargs = {'device': lm_head.weight.device, 'dtype': lm_head.weight.dtype} + lm_head.weight.data = torch.empty((new_num_tokens, lm_head_dim), **factory_kwargs) + if lm_head.bias is not None: + lm_head.bias.data = torch.empty(new_num_tokens, **factory_kwargs) + lm_head.reset_parameters() + model._init_weights(lm_head) + # Copy token embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + lm_head.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :] + if lm_head.bias is not None: + lm_head.bias.data[:num_tokens_to_copy] = old_bias_data[:num_tokens_to_copy] diff --git a/colossalai/shardformer/policies/gpt2.py 
b/colossalai/shardformer/policies/gpt2.py index 9d3ebed2540c..0a4bd5bd48b4 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,4 +1,5 @@ from functools import partial +import math from typing import Callable, Dict, List from torch import Tensor, nn @@ -36,10 +37,10 @@ def preprocess(self): multiple = self.shard_config.make_vocab_size_divisible_by if self.shard_config.enable_tensor_parallelism: world_size = self.shard_config.tensor_parallel_size - multiple = multiple * world_size + multiple = multiple * world_size // (math.gcd(multiple, world_size)) if vocab_size % multiple != 0: new_vocab_size = vocab_size + multiple - vocab_size % multiple - self.model.resize_token_embeddings(new_vocab_size) + self.resize_token_embeddings(self.model, new_vocab_size) return self.model def module_policy(self): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 42bf0825b045..fbe03d72b030 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,4 +1,5 @@ import warnings +import math from functools import partial from typing import Callable, Dict, List, Union @@ -23,15 +24,14 @@ def config_sanity_check(self): pass def preprocess(self): + vocab_size = self.model.config.vocab_size + multiple = self.shard_config.make_vocab_size_divisible_by if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + multiple = multiple * world_size // (math.gcd(multiple, world_size)) + if vocab_size % multiple != 0: + new_vocab_size = vocab_size + multiple - vocab_size % multiple + self.resize_token_embeddings(self.model, new_vocab_size) return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 3155420f1cf2..919f23404294 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -43,7 +43,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config["precision"] == "fp32": - atol, rtol = 1e-4, 1e-3 + atol, rtol = 2e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( From 59fd36073d352234015e7e3663ebeaf2b5a76241 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Sun, 17 Mar 2024 23:09:04 +0800 Subject: [PATCH 09/12] fix resize embedding fix --- .../test_booster/test_plugin/test_3d_plugin.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 285c4866c441..841fae062b68 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -17,6 +17,8 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from tests.kit.model_zoo import model_zoo +from colossalai.tensor.d_tensor.api import 
is_customized_distributed_tensor, is_distributed_tensor +from colossalai.checkpoint_io.utils import gather_distributed_param class RandomDataset(Dataset): @@ -255,12 +257,15 @@ def run_grad_acc_test(test_args): optimizer.step() optimizer.zero_grad() - # tricky code here, shard the origin model inorder to check the parameters in the same stage. - origin_model, origin_optimizer, _, dataloader, _ = booster.boost( - origin_model, origin_optimizer, dataloader=dataloader - ) - for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): - assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + if booster.plugin.stage_manager is None or booster.plugin.stage_manager.is_first_stage(): + for p1, p2 in zip(model.unwrap().parameters(), origin_model.parameters()): + if is_distributed_tensor(p1) or is_customized_distributed_tensor(p1): + p1 = gather_distributed_param(p1, keep_vars=False) + if p1.dim() > 1: + assert_close(p1.to(p2.dtype)[: p2.shape[0], :], p2, atol=1e-2, rtol=1e-2) + else: + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + def run_dist(rank, world_size, port, early_stop: bool = True): From 958be8bc58a9af4cefc7d899992dcc899f55bd10 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Mar 2024 14:55:28 +0800 Subject: [PATCH 10/12] revert --- .../booster/plugin/hybrid_parallel_plugin.py | 3 -- .../shardformer/policies/base_policy.py | 47 +------------------ colossalai/shardformer/policies/gpt2.py | 10 ++-- colossalai/shardformer/policies/llama.py | 14 +++--- colossalai/shardformer/shard/shard_config.py | 3 +- .../test_plugin/test_3d_plugin.py | 16 +++---- .../test_model/test_shard_gpt2.py | 2 +- 7 files changed, 22 insertions(+), 73 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 639f80cf8da2..c37a6b4df72d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -931,7 +931,6 @@ class HybridParallelPlugin(PipelinePluginBase): pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. - make_vocab_size_divisible_by (bool, optional): make the vocabulary size is divisible by `make_vocab_size_divisible_by`, to select a faster CUDA kernel operator. Default to 128. 
""" def __init__( @@ -971,7 +970,6 @@ def __init__( pp_style: str = "1f1b", num_model_chunks: int = 1, enable_metadata_cache: bool = True, - make_vocab_size_divisible_by: int = 128, ) -> None: super().__init__() assert ( @@ -1045,7 +1043,6 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, - make_vocab_size_divisible_by=make_vocab_size_divisible_by, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index bbd8c91af4b2..9a49b1ba6a14 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -5,11 +5,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import torch import torch.nn as nn from torch import Tensor from torch.nn import Module -from colossalai.lazy.lazy_init import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager @@ -244,47 +242,4 @@ def get_stage_index( end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] stage_indices.append([start_idx, end_idx]) - return stage_indices[0] if num_model_chunks == 1 else stage_indices - - - def resize_token_embeddings(self, model, new_num_tokens): - input_embeddings = self.model.get_input_embeddings() - if input_embeddings is not None: - self._resize_token_embeddings(model, input_embeddings, new_num_tokens) - output_embedddings = self.model.get_output_embeddings() - if output_embedddings is not None: - self._resize_lm_head(model, output_embedddings, new_num_tokens) - - def _resize_token_embeddings(self, model, embedding, new_num_tokens): - LazyInitContext.materialize(embedding) - old_num_tokens = embedding.num_embeddings - input_embedding_dim = embedding.embedding_dim - old_weight_data = embedding.weight.data - embedding.num_embeddings = new_num_tokens - if embedding.padding_idx is not None and embedding.padding_idx > new_num_tokens: - embedding.padding_idx = embedding.padding_idx - (old_num_tokens-new_num_tokens) - factory_kwargs = {'device': embedding.weight.device, 'dtype': embedding.weight.dtype} - embedding.weight.data = torch.empty((new_num_tokens, input_embedding_dim), **factory_kwargs) - embedding.reset_parameters() - model._init_weights(embedding) - # Copy token embeddings from the previous weights - num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - embedding.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :] - - def _resize_lm_head(self, model, lm_head, new_num_tokens): - LazyInitContext.materialize(lm_head) - old_num_tokens, lm_head_dim = (lm_head.weight.size()) - old_weight_data = lm_head.weight.data - old_bias_data = lm_head.bias.data if lm_head.bias is not None else None - lm_head.out_features = new_num_tokens - factory_kwargs = {'device': lm_head.weight.device, 'dtype': lm_head.weight.dtype} - lm_head.weight.data = torch.empty((new_num_tokens, lm_head_dim), **factory_kwargs) - if lm_head.bias is not None: - lm_head.bias.data = torch.empty(new_num_tokens, **factory_kwargs) - lm_head.reset_parameters() - model._init_weights(lm_head) - # Copy token embeddings from the previous weights - num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - lm_head.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :] - if lm_head.bias is not None: - lm_head.bias.data[:num_tokens_to_copy] = 
old_bias_data[:num_tokens_to_copy] + return stage_indices[0] if num_model_chunks == 1 else stage_indices \ No newline at end of file diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 0a4bd5bd48b4..8dc5ea7c0dd9 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -33,14 +33,12 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - multiple = self.shard_config.make_vocab_size_divisible_by if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size - multiple = multiple * world_size // (math.gcd(multiple, world_size)) - if vocab_size % multiple != 0: - new_vocab_size = vocab_size + multiple - vocab_size % multiple - self.resize_token_embeddings(self.model, new_vocab_size) + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index fbe03d72b030..2c01d9891418 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -24,14 +24,16 @@ def config_sanity_check(self): pass def preprocess(self): - vocab_size = self.model.config.vocab_size - multiple = self.shard_config.make_vocab_size_divisible_by + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size - multiple = multiple * world_size // (math.gcd(multiple, world_size)) - if vocab_size % multiple != 0: - new_vocab_size = vocab_size + multiple - vocab_size % multiple - self.resize_token_embeddings(self.model, new_vocab_size) + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 49ea382f5cfc..da27341d9c29 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -35,8 +35,9 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True - make_vocab_size_divisible_by: int = 128 extra_kwargs: Dict[str, Any] = field(default_factory=dict) + # TODO padding vocab + # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 841fae062b68..f9c86d5cad3c 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -17,8 +17,6 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from tests.kit.model_zoo import model_zoo -from colossalai.tensor.d_tensor.api import 
is_customized_distributed_tensor, is_distributed_tensor -from colossalai.checkpoint_io.utils import gather_distributed_param class RandomDataset(Dataset): @@ -257,14 +255,12 @@ def run_grad_acc_test(test_args): optimizer.step() optimizer.zero_grad() - if booster.plugin.stage_manager is None or booster.plugin.stage_manager.is_first_stage(): - for p1, p2 in zip(model.unwrap().parameters(), origin_model.parameters()): - if is_distributed_tensor(p1) or is_customized_distributed_tensor(p1): - p1 = gather_distributed_param(p1, keep_vars=False) - if p1.dim() > 1: - assert_close(p1.to(p2.dtype)[: p2.shape[0], :], p2, atol=1e-2, rtol=1e-2) - else: - assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + # tricky code here, shard the origin model inorder to check the parameters in the same stage. + origin_model, origin_optimizer, _, dataloader, _ = booster.boost( + origin_model, origin_optimizer, dataloader=dataloader + ) + for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 919f23404294..3155420f1cf2 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -43,7 +43,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config["precision"] == "fp32": - atol, rtol = 2e-4, 1e-3 + atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( From 1d37b0675735b756353abd23dcd01389e533c865 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Mar 2024 15:01:33 +0800 Subject: [PATCH 11/12] revert --- colossalai/shardformer/policies/gpt2.py | 8 +++----- colossalai/shardformer/policies/llama.py | 7 +++---- tests/test_booster/test_plugin/test_3d_plugin.py | 3 +-- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 8dc5ea7c0dd9..4240727dac66 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,5 +1,4 @@ from functools import partial -import math from typing import Callable, Dict, List from torch import Tensor, nn @@ -29,16 +28,15 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ if self.shard_config.enable_tensor_parallelism: + # Resize embedding vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) + return self.model def module_policy(self): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2c01d9891418..a729a6c1bca3 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -24,16 +24,15 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ if self.shard_config.enable_tensor_parallelism: + # Resize 
embedding vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) + return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index f9c86d5cad3c..38361d803c49 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -260,8 +260,7 @@ def run_grad_acc_test(test_args): origin_model, origin_optimizer, dataloader=dataloader ) for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): - assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) - + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) def run_dist(rank, world_size, port, early_stop: bool = True): From 068c15f04b735b35b4990c4d1154dd0c16f578d3 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Mar 2024 15:04:16 +0800 Subject: [PATCH 12/12] revert --- colossalai/shardformer/policies/gpt2.py | 7 ++++--- colossalai/shardformer/policies/llama.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 4240727dac66..303766993e3d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -28,15 +28,16 @@ def config_sanity_check(self): pass def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ if self.shard_config.enable_tensor_parallelism: - # Resize embedding vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) - return self.model def module_policy(self): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a729a6c1bca3..42bf0825b045 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,5 +1,4 @@ import warnings -import math from functools import partial from typing import Callable, Dict, List, Union
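
---

Editor's note: the recurring idea in patches 06-08 above (later reverted in patches 10-12 in favor of the original world-size-only padding) is to pad the vocabulary up to a multiple of both `make_vocab_size_divisible_by` and the tensor-parallel world size, so the embedding shards evenly and a faster CUDA kernel can be selected. The sketch below is a minimal, standalone illustration of that arithmetic only; the helper name `padded_vocab_size` and its signature are illustrative and not part of the ColossalAI API.

```python
import math


def padded_vocab_size(vocab_size: int,
                      make_vocab_size_divisible_by: int = 128,
                      tp_world_size: int = 1) -> int:
    """Smallest size >= vocab_size divisible by
    lcm(make_vocab_size_divisible_by, tp_world_size), mirroring the
    policy `preprocess` logic introduced in patches 06-08."""
    # lcm(a, b) = a * b // gcd(a, b), as written in the gpt2/llama policies
    multiple = (make_vocab_size_divisible_by * tp_world_size
                // math.gcd(make_vocab_size_divisible_by, tp_world_size))
    if vocab_size % multiple != 0:
        return vocab_size + multiple - vocab_size % multiple
    return vocab_size


# Example: GPT-2's 50257-token vocab, divisor 128, tensor-parallel size 4
# pads to 50304 (= 393 * 128), which also splits evenly across 4 ranks.
assert padded_vocab_size(50257, 128, 4) == 50304
```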