From 739af907d6770838e348831d286b521eca693f54 Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Thu, 21 Mar 2024 12:01:45 +0800 Subject: [PATCH 01/44] flash_attention forward upgrade --- colossalai/shardformer/modeling/llama.py | 29 ++++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index eb8e9f748527..533290538c9a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -11,6 +11,7 @@ ) from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel from transformers.utils import logging +from transformers.cache_utils import Cache from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig @@ -438,11 +439,15 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) bsz, q_len, _ = hidden_states.size() assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." @@ -452,23 +457,23 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - if llama_version == 2: - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) From 976396c03ee5c3cb59b38ecf1ae7bb53ab729839 Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Mon, 25 Mar 2024 10:36:25 +0800 Subject: [PATCH 02/44] llama_model_forward --- colossalai/shardformer/modeling/llama.py | 65 ++++++++++++------------ 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 533290538c9a..b9b1cc1424dc 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -20,7 +20,7 @@ from ..layer._operation import gather_forward_split_backward try: - from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask + from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa LATEST_VERSION = True except ImportError: @@ -63,13 +63,13 @@ def llama_model_forward( # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - batch_size, seq_length = input_ids.shape + batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape + batch_size, seq_length, _ = inputs_embeds.shape[:2] else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -101,24 +101,28 @@ def llama_model_forward( position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + position_ids = position_ids.unsqueeze(0) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None # embed positions, for the first 
stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device - ) if LATEST_VERSION: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) - else: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) + if self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -130,37 +134,32 @@ def llama_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None start_idx, end_idx = stage_index[0], stage_index[1] for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None + #past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, - None, + past_key_values, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -168,7 +167,7 @@ def custom_forward(*inputs): hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) From 63ef374bdc49d54f3267c9c3af3fb1e23b202943 Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Mon, 25 Mar 2024 14:24:57 +0800 Subject: [PATCH 03/44] remove useless comment --- colossalai/shardformer/modeling/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index b9b1cc1424dc..f88bbec27c43 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -141,8 +141,6 @@ def llama_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - #past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: 
layer_outputs = self._gradient_checkpointing_func( From b00f9ea2dbfb161f17a226a058ea0441dfaa049e Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Mon, 25 Mar 2024 17:12:37 +0800 Subject: [PATCH 04/44] update the requirements.txt --- requirements/requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 4136cefc3c37..d42cb08a792f 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -3,7 +3,7 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers==4.33.0 +transformers==4.36.0 timm titans torchaudio From dc8b9d46f1ea40c80c043953e848337c8025ad48 Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Mon, 25 Mar 2024 17:52:35 +0800 Subject: [PATCH 05/44] add the transformers version requirements --- requirements/requirements-test.txt | 1 - requirements/requirements.txt | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index d42cb08a792f..0b15b9311937 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -3,7 +3,6 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers==4.36.0 timm titans torchaudio diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 095617d76355..38b8f66a8f8b 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,3 +16,4 @@ ray sentencepiece google protobuf +transformers==4.36.0 From 9206dd1c5b332f272f853da2a247757a5a2c8a06 Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Tue, 26 Mar 2024 15:50:36 +0800 Subject: [PATCH 06/44] remove the LATEST VERSION try --- colossalai/shardformer/modeling/llama.py | 34 +++++++++--------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index f88bbec27c43..5e7540b96598 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -19,12 +19,7 @@ from ..layer import cross_entropy_1d from ..layer._operation import gather_forward_split_backward -try: - from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa - - LATEST_VERSION = True -except ImportError: - LATEST_VERSION = False +from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa class LlamaPipelineForwards: @@ -106,23 +101,20 @@ def llama_model_forward( if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - - # embed positions, for the first stage, hidden_states is the input embeddings, - # for the other stages, hidden_states is the output of the previous stage - if LATEST_VERSION: - if self._use_sdpa and not output_attentions: + elif self._use_sdpa and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. 
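For reference, a minimal sketch of the two helpers this hunk chooses between, assuming transformers==4.36.0 and toy shapes (not part of the patch):

    import torch
    from transformers.modeling_attn_mask_utils import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )

    bsz, seq_len, past_len, hidden = 2, 5, 0, 8
    inputs_embeds = torch.randn(bsz, seq_len, hidden)
    attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)  # no padding

    # Eager path: always materializes a [bsz, 1, seq_len, past_len + seq_len]
    # additive float mask (0 where attended, dtype-min where masked).
    eager_mask = _prepare_4d_causal_attention_mask(
        attention_mask, (bsz, seq_len), inputs_embeds, past_len
    )

    # SDPA path: may return None (as here, where nothing is padded) so that
    # scaled_dot_product_attention can use its fused causal kernel instead.
    sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask, (bsz, seq_len), inputs_embeds, past_len
    )
    print(eager_mask.shape, sdpa_mask)  # torch.Size([2, 1, 5, 5]) None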
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: From f1ebe544e70bd4f418fa1aa22598e64101f53acd Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Mon, 1 Apr 2024 15:36:27 +0800 Subject: [PATCH 07/44] [shardformer] update bloom model (#5518) * update bloom model * remove the version restriction --- colossalai/shardformer/modeling/bloom.py | 39 +++++++++--------------- colossalai/shardformer/policies/bloom.py | 6 ---- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index d94c30d29e71..5c4e8d1cb703 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -21,7 +21,7 @@ BloomModel, ) from transformers.utils import logging - +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig @@ -205,12 +205,13 @@ def bloom_model_forward( alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = self._prepare_attn_mask( + causal_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape=(batch_size, seq_length), + inputs_embeds=hidden_states, past_key_values_length=past_key_values_length, ) - + causal_mask = causal_mask.bool() # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: @@ -226,21 +227,15 @@ def bloom_model_forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( @@ -1000,11 +995,13 @@ def forward( alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - causal_mask = self._prepare_attn_mask( + causal_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape=(batch_size, seq_length), + inputs_embeds=hidden_states, past_key_values_length=past_key_values_length, ) + causal_mask = causal_mask.bool() # split the input tensor along sequence dimension # [batch_size, 
seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward( @@ -1016,21 +1013,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index eddfafdcbcdc..4fb03c83051f 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -23,12 +23,6 @@ class BloomPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The Bloom model should run on a transformers version not greater than 4.33.0." def config_sanity_check(self): pass From 2cdca4d56ead44a10d68761fc1a7e5fdb1803309 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 11:55:52 +0800 Subject: [PATCH 08/44] [shardformer] update_falcon (#5520) --- colossalai/shardformer/modeling/falcon.py | 286 +++++++++++++++------- colossalai/shardformer/policies/falcon.py | 6 - 2 files changed, 201 insertions(+), 91 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 4e271dfe0fa2..49e9564d8773 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -1,9 +1,14 @@ from typing import List, Optional, Tuple, Union - +import math import torch import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -18,11 +23,13 @@ FalconForTokenClassification, FalconModel, build_alibi_tensor, + apply_rotary_pos_emb, ) from transformers.utils import logging - +import warnings from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig +from torch.nn import functional as F def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -99,11 +106,17 @@ def forward( hidden_states: torch.Tensor, alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) residual = hidden_states if self.config.new_decoder_architecture: @@ -117,10 +130,12 @@ def forward( attention_layernorm_out, layer_past=layer_past, attention_mask=attention_mask, + position_ids=position_ids, alibi=alibi, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + **kwargs, ) attention_output = attn_outputs[0] @@ -166,11 +181,17 @@ def forward( hidden_states: torch.Tensor, alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -178,59 +199,111 @@ def forward( batch_size, query_length, _, _ = query_layer.shape - query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape( - batch_size * num_kv_heads, - query_length, - self.head_dim, - ) - value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] - query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + kv_seq_len = key_layer.shape[-2] + if layer_past is not None: + kv_seq_len += layer_past[0].shape[-2] + if alibi is None: + cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) + query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) if layer_past is not None: past_key, past_value = layer_past # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, kv_length, head_dim] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) + # - key: [batch_size, self.num_heads, kv_length, head_dim] + # - value: [batch_size, self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=-2) + value_layer = torch.cat((past_value, value_layer), dim=-2) - _, kv_length, _ = key_layer.shape + kv_length = key_layer.shape[-2] if use_cache: present = (key_layer, value_layer) else: present = None - attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) + if alibi is None: + if self._use_sdpa and not output_attentions: + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. 
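The `query_length > 1` term in the guard below matters because SDPA's fused causal masking is only needed when a query can attend beyond itself; a rough standalone sketch with toy shapes (assumed, not from the patch):

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 5, 16)  # [batch, num_heads, q_len, head_dim]
    k = torch.randn(2, 4, 5, 16)
    v = torch.randn(2, 4, 5, 16)

    # With attn_mask=None, is_causal=True builds the causal mask inside the
    # kernel. For single-token decoding (q_len == 1) the query may attend to
    # every cached position, so no mask is needed and is_causal stays False.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
    print(out.shape)  # torch.Size([2, 4, 5, 16])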
+ is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + attention_scores = None + else: + attention_scores = query_layer @ key_layer.transpose(-1, -2) + attention_scores /= math.sqrt(self.head_dim) + + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) + # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). + attn_output = attention_scores @ value_layer - query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous() - key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() - value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - if alibi is not None: - attention_mask_float = ( - attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta - ) + attn_output = self.dense(attn_output) - batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1] - tgt_len = key_layer_.size()[1] - attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous() - context_layer = me_attention( - query_layer_, - key_layer_, - value_layer_, - attn_bias=attention_mask_float, - scale=self.inv_norm_factor, - p=self.attention_dropout.p, - ) - batch_size, seq_length, _, _ = context_layer.shape - context_layer = context_layer.reshape(batch_size, seq_length, -1) + if output_attentions: + return attn_output, present, attention_scores + else: + return attn_output, present + else: + if self._use_sdpa and not output_attentions and head_mask is None: + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + + attn_output = self.dense(attn_output) + else: + matmul_result = query_layer @ key_layer.transpose(-1, -2) - output_tensor = self.dense(context_layer) + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - return output_tensor, present + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) + + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size, num_heads, q_length, kv_length] 
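The alibi bias added into `attention_logits` above comes from `build_alibi_tensor`, which this file already imports; a small usage sketch with toy sizes (assumed):

    import torch
    from transformers.models.falcon.modeling_falcon import build_alibi_tensor

    attention_mask = torch.ones(2, 6, dtype=torch.long)  # [batch, kv_length]
    alibi = build_alibi_tensor(attention_mask, num_heads=4, dtype=torch.float32)
    print(alibi.shape)  # torch.Size([8, 1, 6]) = [batch * num_heads, 1, kv_length]

The eager branch views this back to [batch, num_heads, 1, kv_length] and adds it to the attention scores before the softmax, which is what `alibi.view(batch_size, self.num_heads, 1, -1)` does in the hunk above.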
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) + + # change view [batch_size, q_length, num_heads * head_dim] + attn_output = self._merge_heads(attn_output) + + attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_probs + else: + return attn_output, present return forward @@ -246,6 +319,7 @@ def falcon_model_forward( input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -273,18 +347,7 @@ def falcon_model_forward( past_key_values = None return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - else: - past_key_values = self._convert_to_rw_cache(past_key_values) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - + # case: First stage of training if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -295,16 +358,22 @@ def falcon_model_forward( batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - hidden_states = inputs_embeds - else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -312,22 +381,80 @@ def falcon_model_forward( # Compute alibi tensor: check build_alibi_tensor documentation past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) + past_key_values_length = past_key_values[0][0].shape[-2] if self.use_alibi: - alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + mask = ( + torch.ones( + (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long + ) + if attention_mask is None + else attention_mask + ) + alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) else: alibi = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + if alibi is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + elif head_mask is None: + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) + attention_mask_2d = attention_mask + # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # We take care to integrate alibi bias in the attention_mask here. + if attention_mask_2d is None: + attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) + else: + attention_mask = torch.masked_fill( + alibi / math.sqrt(self.config.hidden_size // self.num_heads), + attention_mask < -1, + torch.finfo(alibi.dtype).min, + ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if seq_length > 1: + attention_mask = AttentionMaskConverter._unmask_unattended( + attention_mask, attention_mask_2d, unmasked_value=0.0 + ) + else: + # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. 
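The `torch.masked_fill` call above folds the scaled alibi bias into the additive 4D mask; a condensed sketch of the same arithmetic with made-up shapes:

    import math
    import torch

    batch, num_heads, q_len, kv_len, hidden_size = 1, 2, 4, 4, 32
    alibi = torch.randn(batch, num_heads, 1, kv_len).expand(-1, -1, q_len, -1)
    # Additive 4D mask: 0.0 on attended positions, dtype-min on masked ones.
    mask_4d = torch.zeros(batch, 1, q_len, kv_len)
    mask_4d[..., -1] = torch.finfo(alibi.dtype).min  # pretend the last position is padding

    # Masked positions are forced back to dtype-min so the softmax still
    # zeroes them out even after the alibi bias is added in.
    attention_mask = torch.masked_fill(
        alibi / math.sqrt(hidden_size // num_heads),
        mask_4d < -1,
        torch.finfo(alibi.dtype).min,
    )
    print(attention_mask.shape)  # torch.Size([1, 2, 4, 4])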
+ attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate( @@ -337,31 +464,23 @@ def falcon_model_forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, alibi, - causal_mask, + attention_mask, + position_ids, head_mask[i], + layer_past, + use_cache, + output_attentions, ) else: outputs = block( hidden_states, layer_past=layer_past, - attention_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, @@ -382,9 +501,6 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if presents is not None: - presents = self._convert_cache_to_standard_format(presents, batch_size) - if stage_manager.is_last_stage(): if not return_dict: return tuple( diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 5c148880f980..1d9eb81c9e65 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -21,12 +21,6 @@ class FalconPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The Falcon model should run on a transformers version not greater than 4.33.0." 
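With the series pinning transformers==4.36.0 in requirements.txt, import-time guards like the one deleted here become redundant; if a runtime check were still wanted, it would presumably invert to a minimum-version form, roughly:

    import transformers
    from packaging.version import Version

    # Sketch only: the policies now target the 4.36 API surface
    # (Cache, _gradient_checkpointing_func, 4D mask helpers).
    assert Version(transformers.__version__) >= Version(
        "4.36.0"
    ), "This shardformer policy targets the transformers 4.36 API."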
def config_sanity_check(self): pass From 7686f4e2e078849c6997e646bb28d64d4c9a382f Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:06:57 +0800 Subject: [PATCH 09/44] [shardformer] update mistral model (#5511) --- colossalai/shardformer/modeling/mistral.py | 171 +++++++++++++++++++-- colossalai/shardformer/policies/mistral.py | 18 ++- 2 files changed, 174 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 0da1a35a0278..c325cb284c22 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -1,7 +1,149 @@ from typing import Optional, Tuple import torch +from transformers.modeling_outputs import BaseModelOutputWithPast +from typing import List, Optional, Tuple, Union +import warnings +from transformers.models.mistral.modeling_mistral import MistralModel +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.utils import logging +from transformers.cache_utils import Cache + +logger = logging.get_logger(__name__) + +class MistralForwards: + + @staticmethod + def mistral_model_forward( + self:MistralModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if use_cache: + logger.warning_once("use_cache=True is not supported for Mistral models at the moment.") + use_cache = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._use_flash_attention_2 and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. 
Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv @@ -13,10 +155,15 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) bsz, q_len, _ = hidden_states.size() assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." @@ -30,18 +177,19 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -68,6 +216,9 @@ def forward( attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value return forward diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c0b8b3375836..31ce160463a6 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,11 +1,12 @@ import warnings -from typing import Dict, Union +from functools import partial +from typing import Dict, Union, Callable import torch.nn as nn from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from ..modeling.mistral import get_mistral_flash_attention_forward +from ..modeling.mistral import get_mistral_flash_attention_forward, MistralForwards from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"] @@ -128,6 +129,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model + + def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + method_replacement = { + "forward": partial(new_forward) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) class MistralModelPolicy(MistralPolicy): @@ -135,10 +142,11 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") + policy = super().module_policy() + from transformers.models.mistral.modeling_mistral import MistralModel - return super().module_policy() + self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy) + return policy class MistralForCausalLMPolicy(MistralPolicy): From fd4444058f9ebd5f99cfc60e2e5bf69a7dd38d73 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:07:46 +0800 Subject: [PATCH 10/44] [shardformer] update gpt2 (#5502) --- colossalai/shardformer/modeling/gpt2.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 1e22d9094eae..fdea27479773 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -139,11 +139,9 @@ def gpt2_model_forward( head_mask = self.get_head_mask(head_mask, 
self.config.n_layer) if stage_manager.is_first_stage(): - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - else: + if position_ids is None: position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) @@ -188,22 +186,16 @@ def gpt2_model_forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( From 9a5edc3f18632766995055de207ff88067535d09 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:18:20 +0800 Subject: [PATCH 11/44] [shardformer] update gptj model (#5503) --- colossalai/shardformer/modeling/gptj.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 1990d7df3279..187e35e40dd4 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -123,11 +123,9 @@ def gptj_model_forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) # position id to be assigned not just for the first stage for attn input - if position_ids is not None: - position_ids = position_ids.view(-1, seq_length) - else: + if position_ids is None: position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + position_ids = position_ids.unsqueeze(0) if stage_manager.is_first_stage(): if inputs_embeds is None: inputs_embeds = self.wte(input_ids) @@ -172,21 +170,15 @@ def gptj_model_forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self._gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( @@ -603,7 +595,9 @@ def forward( value = torch.cat((past_value, value), dim=1) if use_cache is True: - present = (key, value) + # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation. 
+ # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128 + present = (key.to(hidden_states.dtype), value) else: present = None From cbff8c0dd0378ece21dad3f95a03ecb46e9a8a23 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:23:03 +0800 Subject: [PATCH 12/44] [shardformer] update opt (#5522) --- colossalai/shardformer/modeling/opt.py | 75 +++++++++----------------- colossalai/shardformer/policies/opt.py | 6 --- 2 files changed, 26 insertions(+), 55 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index d0e267eacd25..095c8c715f84 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,6 +1,5 @@ import random from typing import List, Optional, Tuple, Union - import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( @@ -16,7 +15,7 @@ OPTModel, ) from transformers.utils import logging - +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from colossalai.pipeline.stage_manager import PipelineStageManager @@ -25,33 +24,7 @@ class OPTPipelineForwards: This class serves as a micro library for forward function substitution of OPT models under pipeline setting. """ - - @staticmethod - def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - from transformers.models.opt.modeling_opt import _make_causal_mask - - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - _dtype, - device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to( - device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - + @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ @@ -120,6 +93,7 @@ def opt_model_forward( inputs_embeds = decoder.project_in(inputs_embeds) device = input_ids.device if input_ids is not None else inputs_embeds.device _dtype = inputs_embeds.dtype + hidden_states = inputs_embeds else: if hidden_states is None: @@ -133,17 +107,26 @@ def opt_model_forward( # required mask seq length can be calculated via length of past mask_seq_length = past_key_values_length + seq_length # embed positions - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - elif attention_mask.shape[1] != mask_seq_length: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)" + if self.decoder._use_flash_attention_2: + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask + ) + else: + # 4d mask is passed through 
the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length ) - - causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask( - attention_mask, input_shape, _dtype, device, past_key_values_length - ) if stage_manager.is_first_stage(): pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) @@ -202,20 +185,14 @@ def opt_model_forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if decoder.gradient_checkpointing and decoder.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index a542808ba794..c7b9853e5c37 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -24,12 +24,6 @@ class OPTPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The OPT model should run on a transformers version not greater than 4.33.0." 
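The Flash Attention 2 branch added to the OPT forward above keeps the raw 2D padding mask instead of building a 4D additive one; the recurring idiom, shown standalone with a toy mask:

    import torch

    attention_mask = torch.ones(2, 6, dtype=torch.long)  # no padded positions

    # FA2 kernels take the 2D padding mask directly; when no position is
    # masked (no zeros present), passing None skips masking work entirely.
    fa2_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    print(fa2_mask)  # None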
def config_sanity_check(self): pass From 46479fbe6304ad6e772d04cfb2178e8ea3599104 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:23:37 +0800 Subject: [PATCH 13/44] [shardformer] update t5 model (#5524) --- colossalai/shardformer/modeling/t5.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 9c5ce3fb65c9..94f4fce74501 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -118,15 +118,12 @@ def t5_stack_forward( # required mask seq length can be calculated via length of past mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long) - # initialize past_key_values with `None` if past does not exist if past_key_values is None: past_key_values = [None] * len(self.block) + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. @@ -138,7 +135,9 @@ def t5_stack_forward( encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -162,15 +161,8 @@ def t5_stack_forward( torch.cuda.set_device(hidden_states.device) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -180,6 +172,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( From d7af2d8aab2b1fe981e7d6735beaa7be69fc1299 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:25:22 +0800 Subject: [PATCH 14/44] [shardformer] update whisper model (#5529) --- colossalai/shardformer/modeling/whisper.py | 46 +++++++++++----------- colossalai/shardformer/policies/whisper.py | 6 --- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index cb8b45ae7d01..8a7e9cd0cf5d 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -19,7 +19,7 @@ 
WhisperModel, ) from transformers.utils import logging - +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from colossalai.pipeline.stage_manager import PipelineStageManager @@ -369,18 +369,12 @@ def whisper_encoder_forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -528,6 +522,20 @@ def whisper_decoder_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # embed positions if input_ids is not None: @@ -535,10 +543,6 @@ def whisper_decoder_forward( else: positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -575,16 +579,8 @@ def whisper_decoder_forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -592,6 +588,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, # past_key_value + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index b5b5db79d9de..df6194a6d2d0 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -28,12 +28,6 @@ class WhisperPolicy(Policy): def __init__(self) -> None: super().__init__() - import transformers - from packaging.version import Version - - assert Version(transformers.__version__) <= Version( - "4.33.0" - ), "The Whisper model should run on a transformers version not greater than 4.33.0." 
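[editor's note] The Whisper decoder change above is the standard transformers-4.36-era three-way mask dispatch (flash-attn 2 / SDPA / eager). A runnable sketch of the same selection logic; the boolean flags are plain arguments here, whereas on a real model they come from self._use_flash_attention_2 and self._use_sdpa:

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

def build_decoder_mask(attention_mask, inputs_embeds, past_len=0,
                       use_flash2=False, use_sdpa=False, output_attentions=False):
    bsz, seq_len = inputs_embeds.shape[:2]
    if use_flash2:
        # flash-attn 2 takes the raw 2D mask, and only when real padding exists
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if use_sdpa and not output_attentions:
        # may return None so SDPA can take its fused is_causal fast path
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (bsz, seq_len), inputs_embeds, past_len)
    return _prepare_4d_causal_attention_mask(
        attention_mask, (bsz, seq_len), inputs_embeds, past_len)

embeds = torch.randn(2, 5, 16)
mask = torch.ones(2, 5, dtype=torch.long)
print(build_decoder_mask(mask, embeds).shape)  # torch.Size([2, 1, 5, 5]) on the eager path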
def config_sanity_check(self): pass From 02d9b88517052ae4c51529f3f156fafc6f3b166e Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 14:05:41 +0800 Subject: [PATCH 15/44] [shardformer] update vit model (#5530) * update vit model * remove the output_hidden_states --- colossalai/shardformer/modeling/vit.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index ab141a74aef8..80d74c2960ac 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -14,6 +14,8 @@ def _encoder_forward( end_idx: int, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, return_dict: bool = True, stage_manager: PipelineStageManager = None, ) -> Union[tuple, BaseModelOutput]: @@ -23,20 +25,14 @@ def _encoder_forward( layer_head_mask = head_mask[i] if head_mask is not None else None if encoder.gradient_checkpointing and encoder.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, False) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - layer_head_mask, - ) + layer_outputs = encoder._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) else: - layer_outputs = layer_module(hidden_states, layer_head_mask, False) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if not stage_manager.is_last_stage(): @@ -112,6 +108,8 @@ def pp_forward( end_idx=stage_index[1], hidden_states=hidden_states, head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, return_dict=return_dict, stage_manager=stage_manager, ) From c3e821582dbfda3c59721c2818c30e10322d4fb5 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 12 Apr 2024 13:14:14 +0800 Subject: [PATCH 16/44] [shardformer] fix llama modeling --- colossalai/shardformer/modeling/llama.py | 99 ++++++++++++++++-------- 1 file changed, 66 insertions(+), 33 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index dd2caefc5054..53332726d10e 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,6 +7,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -16,11 +17,12 @@ LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, apply_rotary_pos_emb, repeat_kv, ) from transformers.utils import logging -from transformers.cache_utils import Cache from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( @@ -32,8 +34,6 @@ from ..layer import ColoAttention, cross_entropy_1d -from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa - class LlamaPipelineForwards: """ @@ -107,7 +107,10 @@ def llama_model_forward( if position_ids is None: position_ids = 
torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0) @@ -117,26 +120,33 @@ def llama_model_forward( # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, ) else: if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
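[editor's note] When shardformer's own flash attention is enabled, the hunk above turns attention_mask into a dict of attention kwargs rather than a tensor. A sketch of that path using the exact call from the diff; it needs an initialized ColossalAI/CUDA environment, so treat it as illustrative rather than definitive:

import torch
from colossalai.shardformer.layer import ColoAttention

batch_size, seq_len = 2, 16
q_padding_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
mask_shape = (batch_size, 1, seq_len, seq_len)
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    mask_shape,
    torch.float16,
    torch.device("cuda"),
    q_padding_mask=q_padding_mask,
    is_causal=True,
)
# Downstream, the dict is unpacked into the fused attention call, roughly:
# attn_output = ColoAttention.attention(query, key, value, **attn_kwargs)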
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -159,7 +169,7 @@ def llama_model_forward( num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, num_layers=end_idx - start_idx, - model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), ) assert num_ckpt_layers <= end_idx - start_idx @@ -203,7 +213,16 @@ def llama_model_forward( next_cache = next_decoder_cache if use_cache else None if stage_manager.is_last_stage(): if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -307,7 +326,9 @@ def llama_for_causal_lm_forward( new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) @@ -446,12 +467,10 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - llama_version = 2 try: from transformers.models.llama.modeling_llama import repeat_kv except: warnings.warn("using llamav1, llamav1 hasn't repeat_kv function") - llama_version = 1 def forward( self: LlamaAttention, @@ -494,8 +513,8 @@ def forward( raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) + "with a layer index." 
+ ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -567,7 +586,10 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -581,7 +603,11 @@ def forward( # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, ) if self.gradient_checkpointing and self.training: @@ -736,7 +762,9 @@ def forward( new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, ) if not return_dict: @@ -910,7 +938,10 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -926,7 +957,9 @@ def forward( if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, ) attention_mask = self._prepare_decoder_attention_mask( From 8b72eabfe4547b759cfce157119c0a63a3d1931d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 05:15:30 +0000 Subject: [PATCH 17/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/bloom.py | 3 ++- colossalai/shardformer/modeling/falcon.py | 16 +++++++++------- colossalai/shardformer/modeling/mistral.py | 19 +++++++++---------- colossalai/shardformer/modeling/opt.py | 9 +++++---- colossalai/shardformer/modeling/t5.py | 7 ++----- colossalai/shardformer/modeling/vit.py | 10 +++++----- colossalai/shardformer/modeling/whisper.py | 8 ++++++-- colossalai/shardformer/policies/mistral.py | 10 ++++------ 8 files changed, 42 insertions(+), 40 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 2b2bf89a06e2..c4f326364596 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -6,6 +6,7 @@ from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, 
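[editor's note] The llama hunks above migrate from (key, value) tuples to the transformers >= 4.36 Cache API, which is why the attention class now needs a layer_idx. A runnable sketch with DynamicCache showing the two calls the patches rely on:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0
print(cache.get_usable_length(4, layer_idx))   # 0: nothing cached yet

k = torch.randn(1, 8, 4, 16)                   # (bsz, heads, seq, head_dim)
v = torch.randn(1, 8, 4, 16)
k_all, v_all = cache.update(k, v, layer_idx)   # returns past+new, here just new
k_all, v_all = cache.update(torch.randn(1, 8, 1, 16),
                            torch.randn(1, 8, 1, 16), layer_idx)
print(k_all.shape)                             # torch.Size([1, 8, 5, 16])
print(cache.get_usable_length(1, layer_idx))   # 5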
CausalLMOutputWithCrossAttentions, @@ -21,7 +22,7 @@ BloomModel, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 49e9564d8773..34754ecdbac9 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -1,9 +1,12 @@ -from typing import List, Optional, Tuple, Union import math +import warnings +from typing import List, Optional, Tuple, Union + import torch import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_causal_attention_mask, @@ -22,14 +25,13 @@ FalconForSequenceClassification, FalconForTokenClassification, FalconModel, - build_alibi_tensor, apply_rotary_pos_emb, + build_alibi_tensor, ) from transformers.utils import logging -import warnings + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from torch.nn import functional as F def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -171,7 +173,7 @@ def forward( def get_falcon_flash_attention_forward(): try: - from xformers.ops import memory_efficient_attention as me_attention + pass except: raise ImportError("Error: xformers module is not installed. 
Please install it to use flash attention.") from transformers.models.falcon.modeling_falcon import FalconAttention @@ -347,7 +349,7 @@ def falcon_model_forward( past_key_values = None return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + # case: First stage of training if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -449,7 +451,7 @@ def falcon_model_forward( attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) - + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index c325cb284c22..3b876bcab96a 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -1,21 +1,20 @@ -from typing import Optional, Tuple +import warnings +from typing import List, Optional, Tuple, Union import torch +from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast -from typing import List, Optional, Tuple, Union -import warnings from transformers.models.mistral.modeling_mistral import MistralModel -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.utils import logging -from transformers.cache_utils import Cache logger = logging.get_logger(__name__) + class MistralForwards: - @staticmethod def mistral_model_forward( - self:MistralModel, + self: MistralModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -94,7 +93,6 @@ def mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -123,7 +121,7 @@ def mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -145,6 +143,7 @@ def mistral_model_forward( attentions=all_self_attns, ) + def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv @@ -218,7 +217,7 @@ def forward( if not output_attentions: attn_weights = None - + return attn_output, attn_weights, past_key_value return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 0a31820876ad..de5b1a267cd7 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,7 +1,9 @@ import random from typing import List, Optional, Tuple, Union + import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -15,7 +17,7 @@ OPTModel, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + from colossalai.pipeline.stage_manager import PipelineStageManager 
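[editor's note] The falcon hunk above shows pre-commit's autofix replacing the unused xformers import with `try: pass`, which no longer exercises the import, so the except branch can never fire. A small guard that preserves the original availability check (a sketch, not part of the patch):

import importlib.util

if importlib.util.find_spec("xformers") is None:
    raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from xformers.ops import memory_efficient_attention as me_attention  # noqa: F401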
from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -55,7 +57,7 @@ class OPTPipelineForwards: This class serves as a micro library for forward function substitution of OPT models under pipeline setting. """ - + @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ @@ -70,7 +72,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - @staticmethod def opt_model_forward( self: OPTModel, @@ -125,7 +126,7 @@ def opt_model_forward( if decoder.project_in is not None: inputs_embeds = decoder.project_in(inputs_embeds) device = input_ids.device if input_ids is not None else inputs_embeds.device - _dtype = inputs_embeds.dtype + inputs_embeds.dtype hidden_states = inputs_embeds else: if hidden_states is None: diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 94f4fce74501..b35bb6b94991 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -3,7 +3,6 @@ import torch from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -121,7 +120,7 @@ def t5_stack_forward( # initialize past_key_values with `None` if past does not exist if past_key_values is None: past_key_values = [None] * len(self.block) - + if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=device) @@ -135,9 +134,7 @@ def t5_stack_forward( encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long - ) + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 401973ce4dfe..67b10988d100 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -26,11 +26,11 @@ def _encoder_forward( if encoder.gradient_checkpointing and encoder.training: layer_outputs = encoder._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 6997f181c9ee..509fc3dac86f 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -5,6 +5,10 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -21,7 +25,7 @@ shift_tokens_right, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils 
import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -695,7 +699,7 @@ def whisper_decoder_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 31ce160463a6..3645cf3694fa 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,12 +1,12 @@ import warnings from functools import partial -from typing import Dict, Union, Callable +from typing import Callable, Dict, Union import torch.nn as nn from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from ..modeling.mistral import get_mistral_flash_attention_forward, MistralForwards +from ..modeling.mistral import MistralForwards, get_mistral_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"] @@ -129,11 +129,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model - + def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = { - "forward": partial(new_forward) - } + method_replacement = {"forward": partial(new_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) From 46b90f7ff9b01f796813749b4da345f4ad3a1258 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 16 Apr 2024 17:49:21 +0800 Subject: [PATCH 18/44] [zero] support multiple (partial) backward passes (#5596) * [zero] support multiple (partial) backward passes * [misc] update requirements --- .../low_level/bookkeeping/bucket_store.py | 2 + colossalai/zero/low_level/low_level_optim.py | 67 +++++++++++++++---- requirements/requirements.txt | 2 +- 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index f395fc60ec42..2ebc704f74e6 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -11,7 +11,9 @@ class BucketStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) + self.reset_all() + def reset_all(self) -> None: # init self.current_group_id = 0 self._num_elements_in_bucket = 0 diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index bbbaf13b53ef..cbcf7269738c 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -40,7 +40,13 @@ def __init__( max_scale: float = 2**32, ) -> None: super().__init__( - initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, ) self.num_working_param_groups = num_working_param_groups self.grad_store 
= grad_store @@ -273,11 +279,10 @@ def _create_master_param_current_rank(self, param_list): # Backward Reduction Hook # ########################### - def _grad_handler(self, param, group_id, grad): + def _grad_handler(self, group_id, param): # if run with no_sync context, would not sync grad when backward if self.require_grad_sync: self._add_to_bucket(param, group_id) - return grad def _attach_reduction_hook(self): # we iterate over the working params @@ -286,7 +291,7 @@ def _attach_reduction_hook(self): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - param.register_hook(partial(self._grad_handler, param, group_id)) + param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id)) ####################### # Reduction Functions # @@ -415,7 +420,10 @@ def _run_reduction(self): recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) self._update_partitoned_grad( - non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1 + non_moe_grad_in_bucket_current_rank, + recieved_grad, + group_id, + 1, ) if len(moe_grad_list) > 0: @@ -423,7 +431,11 @@ def _run_reduction(self): moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) ) recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg) + dist.reduce_scatter( + recieved_grad, + flat_grads_list, + group=self.moe_extra_dp_pg, + ) param_slice = self._world_size // self.moe_extra_dp_pg_size recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) for split_recieved_grad in recieved_grad: @@ -444,14 +456,25 @@ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List self._add_grad(grad, self._world_size, group_id, param_id, rank) def _update_partitoned_grad( - self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int + self, + origin_grad_list: List, + flat_grad: torch.Tensor, + group_id: int, + partition_num: int, ) -> None: sync_tensor(flat_grad, origin_grad_list) for grad in origin_grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) self._add_grad(grad, partition_num, group_id, param_id) - def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None: + def _add_grad( + self, + grad: torch.Tensor, + partition_num: int, + group_id: int, + param_id: int, + rank: int = 0, + ) -> None: if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: @@ -534,6 +557,7 @@ def zero_grad(self, set_to_none=True): if param.grad is not None: param.grad.detach() param.grad.zero_() + self._bucket_store.reset_all() #################### # Update Parameter # @@ -655,14 +679,20 @@ def step(self, closure=None): for _ in range(self.moe_extra_dp_pg_size) ] dist.all_gather( - all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg + all_splited_param, + splited_param.to(device).to(self._dtype), + group=self.moe_extra_dp_pg, ) else: all_splited_param = [ torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size) ] - dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg) + dist.all_gather( + all_splited_param, + splited_param.to(device).to(self._dtype), + group=self.dp_pg, + ) 
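[editor's note] The hook migration above depends on Tensor.register_post_accumulate_grad_hook, new in torch 2.1 (hence the requirements bump in this same patch). Unlike register_hook, the new hook fires only after the .grad buffer is fully accumulated and receives the parameter itself, so the old closure over param and the grad return value both disappear. A runnable sketch:

import torch

p = torch.randn(4, requires_grad=True)

def grad_handler(param):               # no grad argument, no return value
    print("accumulated:", param.grad.norm().item())

p.register_post_accumulate_grad_hook(grad_handler)
(p * 2).sum().backward()               # prints once the gradient lands in p.grad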
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] @@ -685,7 +715,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + [float(total_norm)], + device=get_accelerator().get_current_device(), + dtype=torch.float, ) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -698,10 +730,14 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo # Sum across all model parallel GPUs. total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float + [float(total_norm_exponentiated)], + device=get_accelerator().get_current_device(), + dtype=torch.float, ) torch.distributed.all_reduce( - total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_pg, ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) @@ -920,5 +956,8 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: if hasattr(self, "moe_master_to_working_map"): - return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map} + return { + **self._param_store.master_to_working_param, + **self.moe_master_to_working_map, + } return self._param_store.master_to_working_param diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 38b8f66a8f8b..b0352230788a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=1.12 +torch>=2.1.0 safetensors einops pydantic From b15b9644ba3d461755c13944d417f472f8c4d96a Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 16 Apr 2024 17:49:21 +0800 Subject: [PATCH 19/44] [zero] support multiple (partial) backward passes (#5596) * [zero] support multiple (partial) backward passes * [misc] update requirements --- .../low_level/bookkeeping/bucket_store.py | 2 + colossalai/zero/low_level/low_level_optim.py | 67 +++++++++++++++---- requirements/requirements.txt | 2 +- 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index f395fc60ec42..2ebc704f74e6 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -11,7 +11,9 @@ class BucketStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) + self.reset_all() + def reset_all(self) -> None: # init self.current_group_id = 0 self._num_elements_in_bucket = 0 diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index bbbaf13b53ef..cbcf7269738c 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -40,7 +40,13 @@ def __init__( max_scale: float = 2**32, ) -> None: super().__init__( - initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, 
hysteresis, max_scale + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, ) self.num_working_param_groups = num_working_param_groups self.grad_store = grad_store @@ -273,11 +279,10 @@ def _create_master_param_current_rank(self, param_list): # Backward Reduction Hook # ########################### - def _grad_handler(self, param, group_id, grad): + def _grad_handler(self, group_id, param): # if run with no_sync context, would not sync grad when backward if self.require_grad_sync: self._add_to_bucket(param, group_id) - return grad def _attach_reduction_hook(self): # we iterate over the working params @@ -286,7 +291,7 @@ def _attach_reduction_hook(self): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - param.register_hook(partial(self._grad_handler, param, group_id)) + param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id)) ####################### # Reduction Functions # @@ -415,7 +420,10 @@ def _run_reduction(self): recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) self._update_partitoned_grad( - non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1 + non_moe_grad_in_bucket_current_rank, + recieved_grad, + group_id, + 1, ) if len(moe_grad_list) > 0: @@ -423,7 +431,11 @@ def _run_reduction(self): moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) ) recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg) + dist.reduce_scatter( + recieved_grad, + flat_grads_list, + group=self.moe_extra_dp_pg, + ) param_slice = self._world_size // self.moe_extra_dp_pg_size recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) for split_recieved_grad in recieved_grad: @@ -444,14 +456,25 @@ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List self._add_grad(grad, self._world_size, group_id, param_id, rank) def _update_partitoned_grad( - self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int + self, + origin_grad_list: List, + flat_grad: torch.Tensor, + group_id: int, + partition_num: int, ) -> None: sync_tensor(flat_grad, origin_grad_list) for grad in origin_grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) self._add_grad(grad, partition_num, group_id, param_id) - def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None: + def _add_grad( + self, + grad: torch.Tensor, + partition_num: int, + group_id: int, + param_id: int, + rank: int = 0, + ) -> None: if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: @@ -534,6 +557,7 @@ def zero_grad(self, set_to_none=True): if param.grad is not None: param.grad.detach() param.grad.zero_() + self._bucket_store.reset_all() #################### # Update Parameter # @@ -655,14 +679,20 @@ def step(self, closure=None): for _ in range(self.moe_extra_dp_pg_size) ] dist.all_gather( - all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg + all_splited_param, + splited_param.to(device).to(self._dtype), + group=self.moe_extra_dp_pg, ) else: all_splited_param = [ torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size) ] - 
dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg) + dist.all_gather( + all_splited_param, + splited_param.to(device).to(self._dtype), + group=self.dp_pg, + ) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] @@ -685,7 +715,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + [float(total_norm)], + device=get_accelerator().get_current_device(), + dtype=torch.float, ) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -698,10 +730,14 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo # Sum across all model parallel GPUs. total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float + [float(total_norm_exponentiated)], + device=get_accelerator().get_current_device(), + dtype=torch.float, ) torch.distributed.all_reduce( - total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg + total_norm_exponentiated_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_pg, ) total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) @@ -920,5 +956,8 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: if hasattr(self, "moe_master_to_working_map"): - return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map} + return { + **self._param_store.master_to_working_param, + **self.moe_master_to_working_map, + } return self._param_store.master_to_working_param diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 38b8f66a8f8b..b0352230788a 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=1.12 +torch>=2.1.0 safetensors einops pydantic From 4f5fee4ec33ba39ad20644936fc392baae4cbbd4 Mon Sep 17 00:00:00 2001 From: wangbinluo <2538539015@qq.com> Date: Thu, 18 Apr 2024 09:39:21 +0000 Subject: [PATCH 20/44] fix conflicts --- colossalai/shardformer/modeling/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 53332726d10e..81ca4a25330a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -962,7 +962,7 @@ def forward( device=inputs_embeds.device, ) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length ) From b323f0a1537e203b7141e3ed70793dbc84ee2a9c Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Mon, 15 Apr 2024 18:06:18 +0800 Subject: [PATCH 21/44] [doc] fix ColossalMoE readme (#5599) * fix readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 
applications/ColossalMoE/README.md | Bin 6475 -> 1023 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md index ba864d1dff8b2c52b3a5c45261f37586ecdb5bc1..c3c214789f54c90a4b10be96b63619d8c4a42c6f 100644 GIT binary patch delta 7 OcmX?Y^q+mhe`Wv==mWL@ literal 6475 zcmeH`!EW0y42E~sQxM3Z8G)p7IE34}rn4Hi zm2iEn2Elbiq3TQRRE@E^TnZHdNievO zfl(i5vj*zbKD)TCVC}iKn<|OYw{{$0n4KcH`Hm zNg*_!;A^v3z}uU}!@J}zkA^;x{02E94s>V7ev3ZYd3f3c2+EB{!j?v$_d4<49^r-t z|9^*cc0ogWQ+|a&Ay5bu0);>!PzV$Pg+L)t2owT^Kp{{F6as}nAy5bu0)HdG{s14p BSs?%b From 7cecde11a599799d4a8ce632cd7afa3b5c45a568 Mon Sep 17 00:00:00 2001 From: wangbinluo <2538539015@qq.com> Date: Thu, 18 Apr 2024 10:21:19 +0000 Subject: [PATCH 22/44] merge with main --- colossalai/booster/plugin/gemini_plugin.py | 6 +- .../booster/plugin/hybrid_parallel_plugin.py | 11 +- .../hybrid_parallel_checkpoint_io.py | 29 ++- colossalai/checkpoint_io/utils.py | 9 + colossalai/shardformer/layer/__init__.py | 7 +- colossalai/shardformer/layer/embedding.py | 111 +++++++-- colossalai/shardformer/layer/linear.py | 220 +++++++++++++++++- colossalai/shardformer/layer/loss.py | 32 ++- .../shardformer/layer/parallel_module.py | 192 ++++++++++++++- colossalai/shardformer/modeling/gpt2.py | 13 +- colossalai/shardformer/modeling/llama.py | 3 +- .../shardformer/policies/base_policy.py | 9 + colossalai/shardformer/policies/bert.py | 52 +++-- colossalai/shardformer/policies/blip2.py | 62 +++-- colossalai/shardformer/policies/bloom.py | 47 ++-- colossalai/shardformer/policies/chatglm2.py | 40 ++-- colossalai/shardformer/policies/falcon.py | 49 ++-- colossalai/shardformer/policies/gpt2.py | 71 ++++-- colossalai/shardformer/policies/gptj.py | 55 +++-- colossalai/shardformer/policies/llama.py | 53 +++-- colossalai/shardformer/policies/mistral.py | 57 +++-- colossalai/shardformer/policies/opt.py | 64 +++-- colossalai/shardformer/policies/t5.py | 101 ++++++-- colossalai/shardformer/policies/whisper.py | 45 ++-- colossalai/shardformer/shard/shard_config.py | 3 +- .../tensor/d_tensor/layout_converter.py | 17 +- colossalai/tensor/padded_tensor/__init__.py | 3 + colossalai/tensor/padded_tensor/api.py | 128 ++++++++++ colossalai/testing/comparison.py | 2 +- colossalai/zero/gemini/gemini_ddp.py | 22 ++ colossalai/zero/gemini/gemini_optimizer.py | 45 +++- ...st_hybrid_parallel_plugin_checkpoint_io.py | 3 +- .../test_vocab_parallel_embedding_1d.py | 2 +- tests/test_shardformer/test_model/_utils.py | 11 +- .../test_model/test_shard_t5.py | 3 +- tests/test_tensor/test_padded_tensor.py | 46 ++++ 36 files changed, 1341 insertions(+), 282 deletions(-) create mode 100644 colossalai/tensor/padded_tensor/__init__.py create mode 100644 colossalai/tensor/padded_tensor/api.py create mode 100644 tests/test_tensor/test_padded_tensor.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 6c503377326a..442ac4a8da06 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -44,10 +44,10 @@ def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A mapping from integer param_id to param32 shape. 
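[editor's note] The gemini_plugin hunk below slims get_param_info down to a single id-to-shape map. A sketch of what the trimmed function records (the exact body is assumed from context; only the "id2shape" key and the running start_index appear in this diff):

import torch
from torch.optim import Adam

model = torch.nn.Linear(4, 2)
optim = Adam(model.parameters(), lr=1e-3)

param_info = {"id2shape": {}}
start_index = 0
for group in optim.param_groups:
    for param_id, param in enumerate(group["params"], start_index):
        param_info["id2shape"][param_id] = param.shape
    start_index += len(group["params"])

print(param_info["id2shape"])   # {0: torch.Size([2, 4]), 1: torch.Size([2])}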
- if optim is None: return {} param_info = {"id2shape": {}} + start_index = 0 for group in optim.param_groups: for param_id, param in enumerate(group["params"], start_index): @@ -527,7 +527,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - optimizer_params_info = get_param_info(optimizer) + params_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -558,7 +558,7 @@ def configure( **self.zero_optim_config, **self.optim_kwargs, tp_group=self.tp_group, - optimizer_params_info=optimizer_params_info, + params_info=params_info, verbose=self.verbose, ) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 29cec7cfd146..8d12eb80621d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer): if optim is None: return {} - param_info = { - "param_groups": [], - "param2id": {}, - "id2param": {}, - "param2shape": {}, - } + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} start_index = 0 for group in optim.param_groups: packed_group = {k: v for k, v in group.items() if k != "params"} @@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase): num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. + make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. + """ def __init__( @@ -989,6 +986,7 @@ def __init__( num_model_chunks: int = 1, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, ) -> None: super().__init__() assert ( @@ -1095,6 +1093,7 @@ def __init__( sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, ) self.amp_config = dict( diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 80822724982e..7946d9b9c197 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -14,6 +14,12 @@ from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.utils import get_current_device from .general_checkpoint_io import GeneralCheckpointIO @@ -32,6 +38,7 @@ save_param_groups, save_state_dict, save_state_dict_shards, + search_padding_dim, search_tp_partition_dim, sharded_optimizer_loading_epilogue, ) @@ -89,6 +96,8 @@ def _model_sharder( if param is None: continue # Gather tensor pieces when using tensor parallel. 
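[editor's note] The checkpoint-IO hunks below lean on the padded-tensor helpers this PR introduces (names and signatures taken from the diff; round-trip behavior is assumed from the API). A sketch of the pad/unpad cycle plus the shape probe used when resharding optimizer states:

import torch
from colossalai.checkpoint_io.utils import search_padding_dim
from colossalai.tensor.padded_tensor import (
    is_padded_tensor,
    to_padded_tensor,
    to_unpadded_tensor,
)

w = torch.randn(50257, 8)                  # e.g. a GPT-2 sized embedding
padded = to_padded_tensor(w, 50304, 0)     # pad dim 0 up to 50304 rows
assert is_padded_tensor(padded) and padded.shape[0] == 50304

# checkpoint loading asks which dim was padded by comparing shapes:
print(search_padding_dim(padded.shape, w.shape))   # 0

restored = to_unpadded_tensor(padded)      # back to the original 50257 rows
assert restored.shape[0] == 50257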
+ if is_padded_tensor(param): + param = to_unpadded_tensor(param) param_ = gather_distributed_param(param, keep_vars=False) block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: @@ -231,7 +240,6 @@ def save_sharded_model( # When pipeline is used, each stage produces its own shard files and index files. # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. - final_index_file_path = copy.deepcopy(save_index_file) tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) @@ -251,6 +259,7 @@ def save_sharded_model( use_safetensors=use_safetensors, use_pp_format=True, ) + if control_saving: assert ( self.dp_rank == 0 and self.tp_rank == 0 @@ -867,6 +876,11 @@ def gather_from_sharded_optimizer_state( dist.all_gather(gather_tensor, v, group=tp_group) v = torch.cat(gather_tensor, dim=partition_dim) + padding_dim = search_padding_dim(v.shape, original_shape) + if padding_dim is not None: + v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim) + v = to_unpadded_tensor(v) + state_[k] = v.detach().clone().to(device) return state_ @@ -899,6 +913,19 @@ def shard_from_complete_optimizer_state( if isinstance(v, torch.Tensor) and k != "step": # Shard state along tensor parallel group. partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + global_shape = current_shape + if partition_dim is not None: + # pad embedding params + global_shape = ( + *current_shape[:partition_dim], + current_shape[partition_dim] * self.tp_size, + *current_shape[partition_dim + 1 :], + ) + + padding_dim = search_padding_dim(global_shape, original_shape) + if padding_dim is not None: + v = to_padded_tensor(v, global_shape[padding_dim], padding_dim) + if partition_dim is not None: slice_size = current_shape[partition_dim] v = v.split(slice_size, dim=partition_dim)[self.tp_rank] diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 2a1d4de9b036..6197be9d1c8d 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz return partition_dim +def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]: + padding_dim = None + for dim, length in enumerate(global_shape): + if length > original_shape[dim]: + padding_dim = dim + break + return padding_dim + + # ====================================== # Helper classes and functions for saving shard file # ====================================== diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 7b8aa53800f0..f17fad1b6606 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,8 +1,8 @@ from ._operation import all_to_all_comm from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput -from .embedding import Embedding1D, VocabParallelEmbedding1D -from .linear import Linear1D_Col, Linear1D_Row +from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .loss import 
cross_entropy_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule @@ -25,6 +25,9 @@ "FusedRMSNorm", "FusedLinear1D_Col", "ParallelModule", + "PaddingEmbedding", + "PaddingLMHead", + "VocabParallelLMHead1D", "AttnMaskType", "ColoAttention", "all_to_all_comm", diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index d081b204093b..cb7eceae4d25 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -21,10 +21,10 @@ ) from ._operation import gather_forward_split_backward, reduce_forward -from .parallel_module import ParallelModule +from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset -__all__ = ["Embedding1D", "VocabParallelEmbedding1D"] +__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"] class Embedding1D(ParallelModule): @@ -161,7 +161,80 @@ def forward(self, input_: Tensor) -> Tensor: return output_parallel -class VocabParallelEmbedding1D(ParallelModule): +class PaddingEmbedding(PaddingParallelModule): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + weight: Optional[nn.Parameter] = None, + make_vocab_size_divisible_by: int = 64, + *args, + **kwargs, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.embed_args = args + self.embed_kwargs = kwargs + self.padding_idx = padding_idx + if num_embeddings % make_vocab_size_divisible_by != 0: + self.num_embeddings = ( + num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by) + ) + # create weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + super().__init__(self.num_embeddings, num_embeddings, weight) + + if weight is None: + self.reset_parameters() + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + @staticmethod + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> PaddingParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + LazyInitContext.materialize(module) + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + # create the parallel module + padding_embedding = PaddingEmbedding( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + weight=module.weight, + *args, + **kwargs, + ) + + return padding_embedding + + +class VocabParallelEmbedding1D(PaddingParallelModule): r"""Embedding parallelized in the vocabulary dimension. 
Args: @@ -201,10 +274,10 @@ def __init__( process_group: ProcessGroup = None, weight: Optional[nn.Parameter] = None, weight_initializer: Callable = init.normal_(), + make_vocab_size_divisible_by: int = 64, *args, **kwargs, ): - super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.embed_args = args @@ -214,8 +287,23 @@ def __init__( tensor_parallel_size = dist.get_world_size(group=process_group) tensor_parallel_rank = dist.get_rank(group=process_group) - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings = self.num_embeddings_per_partition + # generate weight and bias + if weight is None: + factory_kwargs = {"device": device, "dtype": dtype} + weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) + else: + weight.data = weight.data.to(device=device, dtype=dtype) + + # calculate new padding size + multiple = make_vocab_size_divisible_by * tensor_parallel_size + if num_embeddings % multiple != 0: + self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple) + + # resize vocabulary size + super().__init__(self.num_embeddings, num_embeddings, weight) + + # deal with tensor parallelism + self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size) self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition @@ -226,13 +314,6 @@ def __init__( seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - # parameter - if weight is None: - factory_kwargs = {"device": device, "dtype": dtype} - self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) - else: - weight.data = weight.data.to(device=device, dtype=dtype) - self.weight = weight if not is_distributed_tensor(self.weight): sharded_weight = shard_rowwise(self.weight.data, process_group) sharded_tensor_to_existing_param(sharded_weight, self.weight) @@ -243,7 +324,7 @@ def __init__( @staticmethod def from_native_module( module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: + ) -> PaddingParallelModule: r""" Convert a native pytorch embedding module to a parallel module. """ @@ -303,11 +384,9 @@ def forward(self, input_: Tensor) -> Tensor: # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding( masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs ) - # Mask the output embedding. 
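[editor's note] A numeric sketch of the padding-plus-partition rule implemented above (vocab 50257, make_vocab_size_divisible_by=64, tp_size=4), followed by the out-of-range masking each rank applies before its local lookup:

import torch

def padded_vocab_size(num_embeddings, divisible_by, tp_size):
    multiple = divisible_by * tp_size
    if num_embeddings % multiple != 0:
        num_embeddings += multiple - (num_embeddings % multiple)
    return num_embeddings

vocab = padded_vocab_size(50257, 64, 4)   # 50432, divisible by 64 * 4
per_rank = vocab // 4                     # 12608 rows held by each TP rank
rank = 1
vocab_start, vocab_end = rank * per_rank, (rank + 1) * per_rank

# ids outside [vocab_start, vocab_end) are zeroed, looked up locally, then
# their (meaningless) embeddings are masked to zero before the reduce:
input_ids = torch.tensor([5, 13000, 26000, 50000])
input_mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
masked_input = (input_ids - vocab_start).masked_fill(input_mask, 0)
print(vocab, per_rank, masked_input.tolist())   # 50432 12608 [0, 392, 0, 0]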
embedding_output = output_parallel.clone() embedding_output[input_mask, :] = 0.0 diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 7c8619ad8f5c..37c7542416f6 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -32,7 +32,7 @@ reducescatter_forward_gather_backward, split_forward_gather_backward, ) -from .parallel_module import ParallelModule +from .parallel_module import PaddingParallelModule, ParallelModule from .utils import create_randomizer_with_offset __all__ = ["Linear1D_Col", "Linear1D_Row"] @@ -84,8 +84,9 @@ def __init__( bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, ): - super().__init__() + super().__init__(weight=weight, bias_=bias_, **kwargs) # Keep input parameters self.in_features = in_features @@ -118,6 +119,7 @@ def __init__( else: weight.data = weight.data.to(device=device, dtype=dtype) self.weight = weight + if not is_distributed_tensor(self.weight): sharded_weight = shard_rowwise(self.weight.data, self.process_group) sharded_tensor_to_existing_param(sharded_weight, self.weight) @@ -140,7 +142,7 @@ def __init__( @staticmethod def from_native_module( - module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. @@ -173,7 +175,6 @@ def from_native_module( process_group=process_group, weight=module.weight, bias_=module.bias, - *args, **kwargs, ) @@ -322,7 +323,7 @@ def __init__( @staticmethod def from_native_module( - module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. 
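The padding performed by PaddingEmbedding above (and by VocabParallelEmbedding1D, where the divisor is additionally multiplied by the tensor-parallel world size) boils down to rounding the vocabulary size up to the next multiple. A minimal sketch of that arithmetic, with illustrative sizes, not part of the patch itself:

    def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64, tp_size: int = 1) -> int:
        multiple = make_vocab_size_divisible_by * tp_size
        if vocab_size % multiple != 0:
            vocab_size = vocab_size + multiple - (vocab_size % multiple)
        return vocab_size

    assert padded_vocab_size(32000, 64, 4) == 32000  # already a multiple of 64 * 4 = 256
    assert padded_vocab_size(32001, 64, 4) == 32256  # rounded up; each of the 4 ranks gets 8064 rows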
@@ -356,7 +357,6 @@ def from_native_module(
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            *args,
             **kwargs,
         )

@@ -439,3 +439,211 @@ def forward(self, input_: Tensor) -> Tensor:
             return output
         else:
             return output, self.bias
+
+
+class PaddingLMHead(PaddingParallelModule):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        weight: Optional[Parameter] = None,
+        bias_: Optional[Parameter] = None,
+        make_vocab_size_divisible_by: int = 64,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+    ):
+        # Keep input parameters
+        self.in_features = in_features
+        self.out_features = out_features
+
+        if out_features % make_vocab_size_divisible_by != 0:
+            self.out_features = (
+                out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
+            )
+        # create weight and bias; remember whether they are freshly created, since
+        # they then still need to be initialized after the base class resizes them
+        factory_kwargs = {"device": device, "dtype": dtype}
+        randomize_params = weight is None
+        if randomize_params:
+            weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+
+        if bias:
+            if bias_ is None:
+                bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
+            else:
+                bias_.data = bias_.data.to(device=device, dtype=dtype)
+        else:
+            bias_ = None
+
+        # resize embeddings
+        super().__init__(self.out_features, out_features, weight, bias_)
+
+        if randomize_params:
+            self.reset_parameters(weight_initializer, bias_initializer)
+
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.out_features
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
+    ) -> PaddingParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a parallelized linear layer.
+        """
+        LazyInitContext.materialize(module)
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        lm_head_linear = PaddingLMHead(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            weight=module.weight,
+            bias_=module.bias,
+            **kwargs,
+        )
+
+        return lm_head_linear
+
+    def forward(self, input: Tensor) -> Tensor:
+        output = F.linear(input, self.weight, self.bias)
+        output = output[..., : self.old_num_embeddings]
+        return output
+
+
+class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
+    r"""Linear layer with column parallelism.
+
+    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
+    its second dimension as :math:`A = [A_1, ..., A_p]`.
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+        device (`torch.device`): The device of parameters, defaults to None.
+        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        gather_output (bool, optional): If true, call all-gather on output and make Y available
+                    to all GPUs, otherwise, every GPU will have its output
+                    which is :math:`Y_i = XA_i`, defaults to False
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
+        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+            which is preserved for kernel fusion, defaults to False
+        weight_initializer (`typing.Callable`):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (`typing.Callable`):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    More details about ``initializer`` please refer to
+    `init `_.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        process_group: ProcessGroup = None,
+        weight: Optional[Parameter] = None,
+        bias_: Optional[Parameter] = None,
+        make_vocab_size_divisible_by: int = 64,
+        **kwargs,
+    ):
+        # create weight and bias
+        factory_kwargs = {"device": device, "dtype": dtype}
+        if weight is None:
+            weight = Parameter(torch.empty(out_features, in_features, **factory_kwargs))
+        if bias:
+            if bias_ is None:
+                bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            bias_ = None
+
+        # calculate new vocab size
+        self.tensor_parallel_size = dist.get_world_size(group=process_group)
+        new_out_features = out_features
+        multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
+        if out_features % multiple != 0:
+            new_out_features = out_features + multiple - (out_features % multiple)
+
+        super().__init__(
+            in_features=in_features,
+            out_features=new_out_features,
+            bias=bias,
+            device=device,
+            process_group=process_group,
+            weight=weight,
+            bias_=bias_,
+            **kwargs,
+            new_num_embeddings=new_out_features,
+            old_num_embeddings=out_features,
+        )
+        # get the length of valid embeddings
+        tp_rank = dist.get_rank(process_group)
+        partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
+        if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
+            self.num_valid_embeddings_local = partition_size
+        elif self.old_num_embeddings >= tp_rank * partition_size:
+            self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
+        else:
+            self.num_valid_embeddings_local = 0

+    @staticmethod
+    def from_native_module(
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
+    ) -> PaddingParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a parallelized linear layer.
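+
+        A minimal usage sketch (``hidden_size`` and ``vocab_size`` are placeholders;
+        assumes torch.distributed is initialized and the default process group is used):
+
+            lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+            lm_head = VocabParallelLMHead1D.from_native_module(
+                lm_head, process_group=None, gather_output=True, make_vocab_size_divisible_by=64
+            )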
+ """ + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + lm_head_linear = VocabParallelLMHead1D( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + **kwargs, + ) + + return lm_head_linear + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + # get forward output + if self.skip_bias_add: + output, bias = super().forward(input_) + else: + output = super().forward(input_) + + # delete the padding of output + if self.gather_output: + output = output[..., : self.old_num_embeddings] + else: + output = output[..., : self.num_valid_embeddings_local] + + # return + if self.skip_bias_add: + return output, bias + return output diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index c4cf3fb8517c..6d99efc19bbf 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -15,7 +15,14 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, + process_group: ProcessGroup, + vocab_size: int, + ): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -41,15 +48,21 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) # mask the target in the local device - partition_vocab_size = vocab_logits.size()[-1] rank = dist.get_rank(group=process_group) world_size = dist.get_world_size(group=process_group) - global_vocab_size = partition_vocab_size * world_size + if vocab_size == None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size # [down, up) => false, other device and -100 => true delta = (global_vocab_size + world_size - 1) // world_size down_threshold = rank * delta up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size mask = (target < down_threshold) | (target >= up_threshold) masked_target = target.clone() - down_threshold masked_target[mask] = 0 @@ -57,7 +70,8 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: # reshape the logits and target # reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the labels to [bath_size * seq_len] - logits_2d = vocab_logits.view(-1, partition_vocab_size) + self_vocab_size = vocab_logits.size()[-1] + logits_2d = vocab_logits.view(-1, self_vocab_size) masked_target_1d = masked_target.view(-1) # extract the x[class] and set the x[other device] to zero @@ -104,10 +118,14 @@ def backward(ctx, grad_output): grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits.mul_(grad_output.unsqueeze(dim=-1)) - return grad_logits, None, None, None + return grad_logits, None, None, None, None def cross_entropy_1d( - vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None + vocab_logits: torch.Tensor, + labels: 
torch.Tensor, + ignore_index: int = -100, + process_group: ProcessGroup = None, + vocab_size: int = None, ) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size) diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 6c0d83cc7a20..11ef73538c36 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -3,7 +3,7 @@ import itertools from abc import ABC, abstractmethod -from typing import List, Union +from typing import List, Optional, Union import torch import torch.nn as nn @@ -20,11 +20,15 @@ is_distributed_tensor, sharded_tensor_to_param, ) +from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor __all__ = ["ParallelModule"] class ParallelModule(nn.Module, ABC): + def __init__(self, **kwargs): + super().__init__() + @abstractmethod def from_native_module( module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None @@ -54,7 +58,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: @@ -171,3 +175,187 @@ def _load_from_state_dict( input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) + + +class PaddingParallelModule(ParallelModule): + def __init__( + self, + new_num_embeddings: int, + old_num_embeddings: int, + weight: Optional[nn.Parameter], + bias_: Optional[nn.Parameter] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.new_num_embeddings = new_num_embeddings + self.old_num_embeddings = old_num_embeddings + self.weight = weight + self.bias = bias_ + + if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings): + self.resize_embedding_weight() + + if self.bias is not None and not ( + is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings + ): + self.resize_embedding_bias() + + @abstractmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "PaddingParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + raise NotImplementedError + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. 
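+
+        Unlike ``ParallelModule._save_to_state_dict``, padded parameters are
+        gathered across ranks and then truncated back to their original
+        (unpadded) length before being written out, so saved checkpoints keep
+        the model's true vocabulary size.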
+ + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + param = gather_distributed_param(param, keep_vars=keep_vars) + if is_padded_tensor(param): + param = to_unpadded_tensor(param) + destination[prefix + name] = param.data + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. 
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+
+            if key in state_dict:
+                input_param = state_dict[key]
+                if not torch.overrides.is_tensor_like(input_param):
+                    error_msgs.append(
+                        'While copying the parameter named "{}", '
+                        "expected torch.Tensor or Tensor-like object from checkpoint but "
+                        "received {}".format(key, type(input_param))
+                    )
+                    continue
+
+                if is_padded_tensor(param):
+                    input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
+
+                if is_distributed_tensor(param):
+                    # shard the input param
+                    device_mesh = get_device_mesh(param)
+                    sharding_spec = get_sharding_spec(param)
+                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
+                    input_param = sharded_tensor_to_param(sharded_tensor)
+                elif is_customized_distributed_tensor(param):
+                    input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
+
+                # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they don't have the hook to do the checks;
+                # in such a case, it will error when accessing the .shape attribute.
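+                # at this point `input_param` has been re-padded and re-sharded
+                # to match the local parameter, so the shape check below
+                # compares like with like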
+ is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) + continue + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def resize_embedding_weight(self): + self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0) + + def resize_embedding_bias(self): + self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 26c48c5e2d31..17acdf7fcbba 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -26,7 +26,6 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d -from ..layer._operation import gather_forward_split_backward logger = logging.get_logger(__name__) @@ -389,13 +388,11 @@ def gpt2_lmhead_model_forward( shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) else: loss = loss_fct(shift_logits, shift_labels) - if not shard_config.parallel_output: - lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) - if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1293,12 +1290,12 @@ def forward( shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1) loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) - if not shard_config.parallel_output: - lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) - if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 
81ca4a25330a..0eb08a0432e7 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -329,6 +329,7 @@ def llama_for_causal_lm_forward( shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) @@ -758,13 +759,13 @@ def forward( shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, ) if not return_dict: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index d67ab0a3c6bb..e976672bbfd2 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -195,3 +195,12 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] """ return [] + + def tie_weight_check(self): + input_embedding = self.model.get_input_embeddings() + output_embedding = self.model.get_output_embeddings() + return ( + input_embedding is not None + and output_embedding is not None + and id(input_embedding.weight) == id(output_embedding.weight) + ) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0a61d8cff410..d43fc893aedc 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -37,17 +37,7 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - # TODO: - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -62,6 +52,13 @@ def module_policy(self): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -150,10 +147,6 @@ def module_policy(self): policy[BertEmbeddings] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForReplicatedInput, @@ -168,6 +161,18 @@ def module_policy(self): target_key=BertModel, ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=embedding_cls, + ) + ], + policy=policy, + target_key=BertEmbeddings, + ) + # optimization configuration # Handle bert layer self.append_or_create_submodule_replacement( @@ -237,8 +242,21 @@ def add_lm_head_policy(self, 
base_policy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="decoder", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="decoder", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=base_policy, target_key=BertLMPredictionHead, diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 9be2a1e78073..b845e9336cac 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -17,16 +17,7 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - # TODO: - vocab_size = self.model.config.qformer_config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -43,6 +34,13 @@ def module_policy(self): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -202,22 +200,48 @@ def module_policy(self): ], ) - policy[OPTForCausalLM] = ModulePolicyDescription( - sub_module_replacement=[ + policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="model.decoder.embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), + ], + policy=policy, + target_key=OPTForCausalLM, + ) + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ), - ] + ], + policy=policy, + target_key=OPTForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=OPTForCausalLM, ) - - policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) - # optimization configuration # Handle Blip2EncoderLayer layer self.append_or_create_submodule_replacement( diff --git a/colossalai/shardformer/policies/bloom.py 
b/colossalai/shardformer/policies/bloom.py index 956fd95d2341..4894bda35bfc 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -29,16 +29,7 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -46,6 +37,13 @@ def module_policy(self): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -106,12 +104,19 @@ def module_policy(self): method_replacement={ "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, - sub_module_replacement=[ + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), ], + policy=policy, + target_key=BloomModel, ) # optimization configuration @@ -276,7 +281,21 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), + ), + policy=policy, + target_key=BloomForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ), policy=policy, target_key=BloomForCausalLM, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index dabc14bffc95..f205835e7815 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -25,20 +25,12 @@ def config_sanity_check(self): pass def preprocess(self): - # Resize embedding - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.padded_vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - if self.pipeline_stage_manager is not None: # the batch_size_dim is bounded to Model bsz_dim = 1 setattr(self.model, "batch_size_dim", bsz_dim) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -46,6 +38,13 @@ def module_policy(self) -> Dict[Union[str, 
nn.Module], ModulePolicyDescription]: policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: if self.model.config.rmsnorm: norm_cls = col_nn.FusedRMSNorm @@ -68,16 +67,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription( - attribute_replacement={}, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embedding.word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) - ], - ) - policy[GLMBlock] = ModulePolicyDescription( attribute_replacement={ "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads @@ -114,6 +103,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ], ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + ], + policy=policy, + target_key=ChatGLMModel, + ) # optimization configuration self.append_or_create_submodule_replacement( description=[ diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index a0dfcf7902e9..628e9fdc0d96 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -26,16 +26,7 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -52,6 +43,14 @@ def module_policy(self): warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") policy = {} + + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_tensor_parallelism: attn_attribute_replacement = { "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -92,12 +91,19 @@ def module_policy(self): method_replacement={ "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) }, - sub_module_replacement=[ + ) + + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), ], + policy=policy, + target_key=FalconModel, ) # optimization configuration @@ -226,11 +232,26 @@ def module_policy(self): if 
self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=col_nn.VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), ), policy=policy, target_key=FalconForCausalLM, ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), + ), + policy=policy, + target_key=FalconForCausalLM, + ) + if self.pipeline_stage_manager: self.set_pipeline_forward( model_cls=FalconForCausalLM, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 380a432dc8b8..98db7b948954 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -34,12 +34,7 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -47,6 +42,13 @@ def module_policy(self): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -73,10 +75,6 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="drop", target_module=col_nn.DropoutForParallelInput, @@ -137,6 +135,17 @@ def module_policy(self): ), ], ) + if embedding_cls is not None: + # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="wte", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=GPT2Model, + ) # optimization configuration self.append_or_create_submodule_replacement( @@ -298,8 +307,11 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": False, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ], ) @@ -308,7 +320,19 @@ def module_policy(self): addon_module[GPT2LMHeadModel].method_replacement = { "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) } - module_policy.update(addon_module) + else: + addon_module = { + GPT2LMHeadModel: ModulePolicyDescription( + 
sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + module_policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( @@ -353,13 +377,28 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } - module_policy.update(addon_module) + else: + addon_module = { + GPT2DoubleHeadsModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ] + ) + } + module_policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index eab4c214a41f..4b69137a6892 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -29,22 +29,21 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel policy = {} + + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") @@ -54,10 +53,6 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: policy[GPTJModel] = ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), SubModuleReplacementDescription( suffix="drop", target_module=col_nn.DropoutForParallelInput, @@ -126,6 +121,17 @@ def module_policy(self): ], ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="wte", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=GPTJModel, + ) + # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement( @@ -255,13 +261,28 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + 
target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ) + ] + ) + } + else: + addon_module = { + GPTJForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ) ] ) } - policy.update(addon_module) + policy.update(addon_module) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index bb4551b2c31c..ff686a179553 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -6,7 +6,16 @@ from torch import Tensor from torch.nn import Module -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + RMSNorm, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from ..modeling.llama import ( LlamaPipelineForwards, @@ -26,15 +35,7 @@ def config_sanity_check(self): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -42,6 +43,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedRMSNorm else: @@ -167,10 +175,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ], ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=LlamaModel, @@ -327,8 +337,11 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs={"gather_output": not self.shard_config.parallel_output}, + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ], ) @@ -337,7 +350,19 @@ def module_policy(self): new_item[LlamaForCausalLM].method_replacement = { "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) } - policy.update(new_item) + else: + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ) + ], + ) 
+ } + policy.update(new_item) if self.pipeline_stage_manager: # set None as default diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 3645cf3694fa..ce4864dac6ad 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -4,7 +4,15 @@ import torch.nn as nn -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from ..modeling.mistral import MistralForwards, get_mistral_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -17,15 +25,7 @@ def config_sanity_check(self): pass def preprocess(self): - if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -33,6 +33,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( @@ -81,10 +88,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ], ) + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=MistralModel, @@ -152,6 +161,8 @@ def module_policy(self): from transformers import MistralForCausalLM policy = super().module_policy() + if self.pipeline_stage_manager: + warnings.warn("Mistral doesn't support pipeline parallelism now.") if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm @@ -159,16 +170,30 @@ def module_policy(self): MistralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, + ), + ) + ] + ) + } + else: + new_item = { + MistralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ) ] ) } - if self.pipeline_stage_manager: - warnings.warn("Mistral doesn't support pipeline parallelism now.") - - policy.update(new_item) + policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 
5c753e0af5aa..2bb28b095114 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -5,7 +5,16 @@ import torch.nn as nn from torch import Tensor, nn -from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import ( + FusedLayerNorm, + LayerNorm, + Linear1D_Col, + Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, + VocabParallelEmbedding1D, + VocabParallelLMHead1D, +) from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func @@ -34,16 +43,7 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -51,6 +51,13 @@ def module_policy(self): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedLayerNorm else: @@ -61,14 +68,6 @@ def module_policy(self): warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ] - ) policy[OPTDecoderLayer] = ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( @@ -107,6 +106,17 @@ def module_policy(self): ], ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=OPTDecoder, + ) + # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -246,8 +256,20 @@ def module_policy(self): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + ), + ), + policy=policy, + target_key=OPTForCausalLM, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by), ), policy=policy, target_key=OPTForCausalLM, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 0c8ec15fa0a9..3c7e92b47db0 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -13,8 +13,11 @@ FusedRMSNorm, Linear1D_Col, Linear1D_Row, + PaddingEmbedding, + PaddingLMHead, 
RMSNorm, VocabParallelEmbedding1D, + VocabParallelLMHead1D, ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription @@ -36,16 +39,7 @@ def config_sanity_check(self): pass def preprocess(self): - # reshape the embedding layer - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - if self.shard_config.enable_tensor_parallelism: - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -61,6 +55,13 @@ def module_policy(self): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = FusedRMSNorm else: @@ -77,10 +78,6 @@ def module_policy(self): suffix="dropout", target_module=DropoutForParallelInput, ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ), ] ) policy[T5LayerSelfAttention] = ModulePolicyDescription( @@ -176,6 +173,17 @@ def module_policy(self): ] ) + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=T5Stack, + ) + # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( @@ -370,11 +378,19 @@ def module_policy(self): policy = super().module_policy() + embedding_cls = None if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="shared", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=T5Model, @@ -406,17 +422,44 @@ def module_policy(self): policy = super().module_policy() + embedding_cls = None if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ), - ], + description=SubModuleReplacementDescription( + suffix="shared", + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), + policy=policy, + target_key=T5ForConditionalGeneration, + ) + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": 
self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=policy, + target_key=T5ForConditionalGeneration, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, + ), policy=policy, target_key=T5ForConditionalGeneration, ) @@ -467,11 +510,19 @@ def module_policy(self): policy = super().module_policy() + embedding_cls = None if self.shard_config.enable_tensor_parallelism: + embedding_cls = VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = PaddingEmbedding + + if embedding_cls is not None: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="shared", - target_module=VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=policy, target_key=T5EncoderModel, diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 5a7021a72294..16ed2607c6f7 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -38,11 +38,7 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) + self.tie_weight = self.tie_weight_check() return self.model def module_policy(self): @@ -56,6 +52,13 @@ def module_policy(self): policy = {} + embedding_cls = None + if self.shard_config.enable_tensor_parallelism: + embedding_cls = col_nn.VocabParallelEmbedding1D + else: + if self.tie_weight: + embedding_cls = col_nn.PaddingEmbedding + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -160,13 +163,17 @@ def module_policy(self): ], ) - policy[WhisperDecoder] = ModulePolicyDescription( - sub_module_replacement=[ + if embedding_cls is not None: + self.append_or_create_submodule_replacement( + description=[ SubModuleReplacementDescription( suffix="embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, + target_module=embedding_cls, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), - ] + ], + policy=policy, + target_key=WhisperDecoder, ) # optimization configuration @@ -273,8 +280,21 @@ def add_lm_head_policy(self, base_policy): self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="proj_out", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, + target_module=col_nn.VocabParallelLMHead1D, + kwargs={ + "gather_output": True, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, + ), + policy=base_policy, + target_key=WhisperForConditionalGeneration, + ) + else: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="proj_out", + target_module=col_nn.PaddingLMHead, + kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}, ), policy=base_policy, target_key=WhisperForConditionalGeneration, @@ -519,9 +539,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # WhisperForAudioClassification class 
WhisperForAudioClassificationPolicy(WhisperPolicy): - def preprocess(self): - return self.model - def module_policy(self): from transformers import WhisperForAudioClassification diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 7489873c2ed6..963732543f27 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -42,10 +42,9 @@ class ShardConfig: sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True + make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) - # TODO padding vocab - # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
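For orientation, the padding rule implied by the new `make_vocab_size_divisible_by` field can be sketched as below. This is an illustrative reconstruction rather than code from the patch: the helper name `padded_vocab_size` is hypothetical, and the exact rounding rule lives inside the new PaddingEmbedding / VocabParallelLMHead1D layers.

import math

def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64, tp_size: int = 1) -> int:
    # Hypothetical helper: round the vocabulary up to the nearest multiple that is
    # both divisible by `make_vocab_size_divisible_by` and evenly shardable
    # across `tp_size` tensor-parallel ranks.
    multiple = make_vocab_size_divisible_by * tp_size
    return int(math.ceil(vocab_size / multiple)) * multiple

assert padded_vocab_size(50257) == 50304             # GPT-2-sized vocab, no TP
assert padded_vocab_size(50257, tp_size=4) == 50432  # padded and 4-way shardable

Unlike the old preprocess() hooks, which called resize_token_embeddings and permanently changed the model's vocabulary size, the padded-tensor approach introduced below keeps the logical vocabulary size and strips the padding again when state dicts are saved or loaded.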
+ """ + ptensor._unpad_detach = ptensor.detach + ptensor._unpad_clone = ptensor.clone + + def new_detach(self): + t_ = self._unpad_detach() + t_._padding_dim = self._padding_dim + t_._origin_length = self._origin_length + t_._current_length = self._current_length + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._unpad_clone(*args, **kwargs) + t_._padding_dim = self._padding_dim + t_._origin_length = self._origin_length + t_._current_length = self._current_length + return t_ + + # bind the new methods to the tensor + ptensor.detach = new_detach.__get__(ptensor) + ptensor.clone = new_clone.__get__(ptensor) + return ptensor + + +def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + ptensor.detach = ptensor._unpad_detach + ptensor.clone = ptensor._unpad_clone + + delattr(ptensor, "_unpad_detach") + delattr(ptensor, "_unpad_clone") + + return ptensor + + +def is_padded_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a padding tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a padding tensor. + """ + return hasattr(tensor, "_padding_dim") + + +def to_padded_tensor( + tensor: torch.Tensor, + current_length: int, + padding_dim: int, +) -> torch.Tensor: + assert ( + padding_dim < tensor.dim() + ), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}" + + if is_padded_tensor(tensor): + return tensor + + origin_length = tensor.shape[padding_dim] + padding_num = current_length - origin_length + padding_data = torch.zeros( + *tensor.shape[:padding_dim], + padding_num, + *tensor.shape[padding_dim + 1 :], + device=tensor.device, + dtype=tensor.dtype, + ) + tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous() + + tensor._padding_dim = padding_dim + tensor._origin_length = origin_length + tensor._current_length = current_length + + _hijack_detach_and_clone(tensor) + + return tensor + + +def to_unpadded_tensor(ptensor: torch.Tensor): + if not is_padded_tensor(ptensor): + return ptensor + + unpad_slices = [slice(None)] * ptensor.dim() + unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length) + ptensor.data = ptensor.data[tuple(unpad_slices)] + + delattr(ptensor, "_padding_dim") + delattr(ptensor, "_origin_length") + delattr(ptensor, "_current_length") + + _hijack_back_detach_and_clone(ptensor) + + return ptensor + + +def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int): + if is_padded_tensor(tensor): + return tensor + + tensor._padding_dim = padding_dim + tensor._origin_length = origin_length + tensor._current_length = current_length + + _hijack_detach_and_clone(tensor) + + return tensor diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e415b5fc3aa3..bdf7b19f39d0 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1 rtol=rtol, atol=atol, msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ - dtype: {a.dtype} vs {b.dtype}", + dtype: {a.dtype} vs {b.dtype}", ) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 
bc6c9d088094..c79422171f1b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -27,6 +27,12 @@ is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import _cast_float, free_storage, is_ddp_ignored @@ -460,6 +466,11 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn ) record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() + if is_padded_tensor(tensor): + record_tensor = init_as_padded_tensor( + record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim + ) + record_tensor = to_unpadded_tensor(record_tensor) assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -520,6 +531,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): # deal with ddp ignored parameters destination[prefix + name] = param if keep_vars else param.detach() else: + if is_padded_tensor(p_mapping[param]): + p_mapping[param] = to_unpadded_tensor(p_mapping[param]) destination[prefix + name] = p_mapping[param] del p_mapping del param_to_save_data @@ -627,6 +640,7 @@ def _load_from_state_dict( list, and will be reported together in :meth:`~torch.nn.Module.load_state_dict` """ + for hook in self._load_state_dict_pre_hooks.values(): hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -647,6 +661,14 @@ def load( if state_key in state_dict: input_param = state_dict[state_key] + global_shape = dest_tensor.shape + if source_device_mesh is not None and source_sharding_spec is not None: + global_shape = get_global_shape(dest_tensor) + + if is_padded_tensor(dest_tensor): + padding_dim = dest_tensor._padding_dim + input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim) + if source_device_mesh is not None and source_sharding_spec is not None: input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) elif shard_fn is not None and gather_fn is not None: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 18367af59d80..ae02fe297d88 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -21,12 +21,19 @@ distribute_tensor, distribute_tensor_with_customization, get_device_mesh, + get_global_shape, get_sharding_spec, init_as_dtensor, init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.padded_tensor import ( + init_as_padded_tensor, + is_padded_tensor, + to_padded_tensor, + to_unpadded_tensor, +) from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager @@ -106,7 +113,7 @@ def __init__( max_norm: float = 0.0, norm_type: float = 2.0, tp_group: ProcessGroup = None, - optimizer_params_info=None, + params_info=None, verbose: bool = False, **defaults: Any, ): @@ -124,7 +131,7 @@ def __init__( self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm self.tp_group = tp_group - self.optimizer_params_info = optimizer_params_info + self.params_info = params_info self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 self.tp_rank = 
dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose @@ -459,7 +466,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: is_customized_distributed = is_customized_distributed_tensor(param) shard_spec = get_sharding_spec(param) if is_dtensor else None device_mesh = get_device_mesh(param) if is_dtensor else None - global_shape = self.optimizer_params_info["id2shape"][param_id] + global_shape = self.params_info["id2shape"][param_id] # If the chunk is kept gathered, # the parameters are treated the same as that of those in strict DDP during training. @@ -477,6 +484,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: else: state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() if is_dtensor: + global_shape = get_global_shape(param) state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) state_tensor = init_as_dtensor( state_tensor, @@ -490,8 +498,13 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() - - collected_states[state_name] = state_tensor.reshape(global_shape) + state_tensor = state_tensor.reshape(global_shape) + if is_padded_tensor(param): + state_tensor = init_as_padded_tensor( + state_tensor, param._current_length, param._origin_length, param._padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) + collected_states[state_name] = state_tensor return collected_states # Check whether the param with given id is managed by current process. @@ -535,6 +548,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: if state_tensor.numel() == param.numel(): collected_states[state_name] = torch.reshape(state_tensor, param.shape) if is_dtensor: + global_shape = get_global_shape(param) state_tensor = state_tensor.to(param.device) state_tensor = init_as_dtensor( state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape @@ -545,6 +559,11 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() + if is_padded_tensor(param): + state_tensor = init_as_padded_tensor( + state_tensor, param._current_length, param._origin_length, param._padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) return collected_states @@ -698,7 +717,7 @@ def load_single_param_states(self, param_id: int, saved_states: dict): Load saved optimizer states into parameter with given id. """ - def cast(param, state_range, value, key=None): + def cast(param, state_range, value, global_shape, origin_shape, key=None): """ Make a copy of the needed segment of value and cast it to device of param. 
""" @@ -714,7 +733,14 @@ def cast(param, state_range, value, key=None): ) if is_dtensor: - value = torch.reshape(value, global_shape) + global_shape = get_global_shape(real_param) + + if is_padded_tensor(real_param): + value = torch.reshape(value, origin_shape) + padding_dim = real_param._padding_dim + value = to_padded_tensor(value, global_shape[padding_dim], padding_dim) + + if is_dtensor: value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) elif is_customized_distributed: value = torch.reshape(value, global_shape) @@ -737,10 +763,11 @@ def cast(param, state_range, value, key=None): is_customized_distributed = is_customized_distributed_tensor(real_param) shard_spec = get_sharding_spec(real_param) if is_dtensor else None device_mesh = get_device_mesh(real_param) if is_dtensor else None - global_shape = self.optimizer_params_info["id2shape"][param_id] + global_shape = self.params_info["id2shape"][param_id] + origin_shape = global_shape for k, v in saved_states.items(): - updated_states[k] = cast(fake_param, state_range, v, k) + updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k) del v # clean loaded states self.optim.state[fake_param].update(updated_states) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index d8a625b98a66..4753ab637f01 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -81,8 +81,7 @@ def _preprocess_data(data): optimizer.backward(loss) optimizer.step() - for group in optimizer.param_groups: - group["lr"] = 0.1 + optimizer.zero_grad() with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index b23a44f2dffa..91cc1a987a29 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -21,7 +21,7 @@ def check_vocab_embedding_1d(lazy_init: bool): dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) - assert dist_embedding_1d.num_embeddings == 64 + assert dist_embedding_1d.num_embeddings == 128 assert dist_embedding_1d.embedding_dim == 32 assert embedding_copy.weight is dist_embedding_1d.weight diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d5fc2c30f294..a77ba39a122c 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -14,12 +14,14 @@ from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule +from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer._utils import getattr_ from colossalai.shardformer.policies.auto_policy import Policy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from 
colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor def build_model( @@ -247,11 +249,10 @@ def check_weight( continue if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): - sharded_weight_list = [ - torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) - ] - dist.all_gather(sharded_weight_list, sharded_weight, tp_group) - sharded_weight = torch.cat(sharded_weight_list, dim=dim) + sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False) + + if is_padded_tensor(sharded_weight): + sharded_weight = to_unpadded_tensor(sharded_weight) if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 9b22d54d7d31..a6fe2dd39383 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -73,7 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check weights if test_config["precision"] == "fp32": - atol, rtol = 5e-4, 1e-3 + # TODO: the precision requirement in weight checking is too strict. + atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): diff --git a/tests/test_tensor/test_padded_tensor.py b/tests/test_tensor/test_padded_tensor.py new file mode 100644 index 000000000000..31a267c15286 --- /dev/null +++ b/tests/test_tensor/test_padded_tensor.py @@ -0,0 +1,46 @@ +import torch + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global +from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_padded_tensor(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + original_tensor = torch.rand(32, 64).to("cuda") + + device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) + target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) + d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec) + + padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0) + assert padded_tensor.dist_layout == d_tensor.dist_layout + + tensor_copy = padded_tensor.clone() + assert is_padded_tensor(tensor_copy) + assert is_distributed_tensor(tensor_copy) + + tensor_detached = padded_tensor.detach() + assert is_padded_tensor(tensor_detached) + assert is_distributed_tensor(tensor_detached) + + unpadded_tensor = to_unpadded_tensor(padded_tensor) + assert unpadded_tensor.shape == d_tensor.shape + assert is_distributed_tensor(unpadded_tensor) + + global_tensor = to_global(unpadded_tensor) + assert global_tensor.shape == original_tensor.shape + + +@rerun_if_address_is_in_use() +def test_padded_tensor(): + world_size = 4 + spawn(check_padded_tensor, world_size) + + +if __name__ == "__main__": + test_padded_tensor() From 98eff6d8211a16012a03fb6c987c3b7f4ddeb2c1 Mon Sep 17 00:00:00 2001 From: wangbinluo <2538539015@qq.com> Date: Thu, 18 Apr 2024 10:24:15 +0000 Subject: [PATCH 23/44]
merge with main --- colossalai/shardformer/modeling/llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 0eb08a0432e7..a5b07a4cfe92 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -23,6 +23,7 @@ repeat_kv, ) from transformers.utils import logging +from transformers.cache_utils import Cache from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( @@ -514,8 +515,8 @@ def forward( raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) + "with a layer index." + ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) From 267efc8b221b40557368c2ea524e4a5fcc6dfa72 Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Mon, 25 Mar 2024 10:36:25 +0800 Subject: [PATCH 24/44] llama_model_forward --- colossalai/shardformer/modeling/llama.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a5b07a4cfe92..fab2ebcb1eed 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -34,7 +34,12 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d +from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +<<<<<<< HEAD +======= + +>>>>>>> llama_model_forward class LlamaPipelineForwards: """ @@ -114,6 +119,10 @@ def llama_model_forward( device=device, ) position_ids = position_ids.unsqueeze(0) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage From 0bdcc840dfb89f2a60b6df802e88f17e0cf7224d Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Mon, 25 Mar 2024 14:24:57 +0800 Subject: [PATCH 25/44] remove useless comment --- colossalai/shardformer/modeling/llama.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index fab2ebcb1eed..64ead8e847fd 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -36,10 +36,6 @@ from ..layer import ColoAttention, cross_entropy_1d from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -<<<<<<< HEAD -======= - ->>>>>>> llama_model_forward class LlamaPipelineForwards: """ From e520e0b8d702a63e3a6926b25532e346f219062e Mon Sep 17 00:00:00 2001 From: Wang Binluo <2538539015@qq.com> Date: Tue, 26 Mar 2024 15:50:36 +0800 Subject: [PATCH 26/44] remove the LATEST VERSION try --- colossalai/shardformer/modeling/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 64ead8e847fd..1e4ccb85f1e9 100644 --- 
a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -34,6 +34,7 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d +from ..layer._operation import gather_forward_split_backward from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa From 2d9a21d6f9af6ad5ddd900a057ced775762e9f9a Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Mon, 1 Apr 2024 15:36:27 +0800 Subject: [PATCH 27/44] [shardformer] update bloom model (#5518) * update bloom model * remove the version restriction --- colossalai/shardformer/modeling/bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index c4f326364596..bec6d4ab81b9 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -22,7 +22,7 @@ BloomModel, ) from transformers.utils import logging - +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig From 50b4c869528cb60b2c3c6d342b5f05abac2fe6a7 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:06:57 +0800 Subject: [PATCH 28/44] [shardformer] update mistral model (#5511) --- colossalai/shardformer/modeling/mistral.py | 16 +++++++++------- colossalai/shardformer/policies/mistral.py | 8 +++++++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 3b876bcab96a..d96da3705e77 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -2,19 +2,21 @@ from typing import List, Optional, Tuple, Union import torch -from transformers.cache_utils import Cache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast +from typing import List, Optional, Tuple, Union +import warnings from transformers.models.mistral.modeling_mistral import MistralModel +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.utils import logging +from transformers.cache_utils import Cache logger = logging.get_logger(__name__) - class MistralForwards: + @staticmethod def mistral_model_forward( - self: MistralModel, + self:MistralModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -93,6 +95,7 @@ def mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -121,7 +124,7 @@ def mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -143,7 +146,6 @@ def mistral_model_forward( attentions=all_self_attns, ) - def 
get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv @@ -217,7 +219,7 @@ def forward( if not output_attentions: attn_weights = None - + return attn_output, attn_weights, past_key_value return forward diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index ce4864dac6ad..ccc66f52cdf2 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Callable, Dict, Union +from typing import Dict, Union, Callable import torch.nn as nn @@ -138,6 +138,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model + + def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + method_replacement = { + "forward": partial(new_forward) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: method_replacement = {"forward": partial(new_forward)} From 1233fc22c03574fbdc7ea5c4841a22720bc8999a Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:23:03 +0800 Subject: [PATCH 29/44] [shardformer] update opt (#5522) --- colossalai/shardformer/modeling/opt.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index de5b1a267cd7..71f8ce246a0a 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,6 +1,5 @@ import random from typing import List, Optional, Tuple, Union - import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask @@ -17,7 +16,7 @@ OPTModel, ) from transformers.utils import logging - +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -57,6 +56,20 @@ class OPTPipelineForwards: This class serves as a micro library for forward function substitution of OPT models under pipeline setting. """ + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): From ab160a8b5d434978a2149d59046c450ae282c888 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 3 Apr 2024 12:25:22 +0800 Subject: [PATCH 30/44] [shardformer] update whisper model (#5529) --- colossalai/shardformer/modeling/whisper.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 509fc3dac86f..7ae701aee83a 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -25,7 +25,7 @@ shift_tokens_right, ) from transformers.utils import logging - +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -699,6 +699,20 @@ def whisper_decoder_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) if self._use_flash_attention_2: # 2d mask is passed through the layers From 16a29ff6c09aafb9a6c77ab01c3397921d09b106 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 12 Apr 2024 13:14:14 +0800 Subject: [PATCH 31/44] [shardformer] fix llama modeling --- colossalai/shardformer/modeling/llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 1e4ccb85f1e9..23085ddc158d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -23,7 +23,6 @@ repeat_kv, ) from transformers.utils import logging -from transformers.cache_utils import Cache from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import ( @@ -521,8 +520,8 @@ def forward( raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) + "with a layer index." 
+ ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) From 06d7c30033e1a86973d155709f635d2036e65610 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 05:15:30 +0000 Subject: [PATCH 32/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/bloom.py | 2 +- colossalai/shardformer/modeling/mistral.py | 16 +++++++--------- colossalai/shardformer/modeling/opt.py | 5 +++-- colossalai/shardformer/modeling/whisper.py | 4 ++-- colossalai/shardformer/policies/mistral.py | 8 +++----- 5 files changed, 16 insertions(+), 19 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index bec6d4ab81b9..c4f326364596 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -22,7 +22,7 @@ BloomModel, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index d96da3705e77..3b876bcab96a 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -2,21 +2,19 @@ from typing import List, Optional, Tuple, Union import torch +from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast -from typing import List, Optional, Tuple, Union -import warnings from transformers.models.mistral.modeling_mistral import MistralModel -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.utils import logging -from transformers.cache_utils import Cache logger = logging.get_logger(__name__) + class MistralForwards: - @staticmethod def mistral_model_forward( - self:MistralModel, + self: MistralModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -95,7 +93,6 @@ def mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -124,7 +121,7 @@ def mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -146,6 +143,7 @@ def mistral_model_forward( attentions=all_self_attns, ) + def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv @@ -219,7 +217,7 @@ def forward( if not output_attentions: attn_weights = None - + return attn_output, attn_weights, past_key_value return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 71f8ce246a0a..dd3afecd6723 100644 --- 
a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -1,5 +1,6 @@ import random from typing import List, Optional, Tuple, Union + import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask @@ -16,7 +17,7 @@ OPTModel, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -56,7 +57,7 @@ class OPTPipelineForwards: This class serves as a micro library for forward function substitution of OPT models under pipeline setting. """ - + @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 7ae701aee83a..ae1772a66848 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -25,7 +25,7 @@ shift_tokens_right, ) from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa + from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig @@ -699,7 +699,7 @@ def whisper_decoder_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index ccc66f52cdf2..61e1b5f9c7b4 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Dict, Union, Callable +from typing import Callable, Dict, Union import torch.nn as nn @@ -138,11 +138,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model - + def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = { - "forward": partial(new_forward) - } + method_replacement = {"forward": partial(new_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: From b427fee2ce0a4674c5a110e88ada0d09a9460b86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 10:35:33 +0000 Subject: [PATCH 33/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 23085ddc158d..ac9baad5fdb9 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -34,7 +34,6 @@ from ..layer import ColoAttention, cross_entropy_1d from ..layer._operation 
import gather_forward_split_backward -from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa class LlamaPipelineForwards: @@ -115,7 +114,7 @@ def llama_model_forward( device=device, ) position_ids = position_ids.unsqueeze(0) - + if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None From 0b2584dd8c49c9ca0aa6bd3a3ffa7f7c535b23c3 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 18 Apr 2024 18:15:50 +0800 Subject: [PATCH 34/44] [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606) * fix no pad token bug * fixed some auto parallel codegen bug, but might not run on torch 2.1 --------- Co-authored-by: Edenzzzz --- colossalai/_analyzer/fx/codegen.py | 2 +- colossalai/auto_parallel/offload/base_offload_module.py | 2 +- colossalai/auto_parallel/offload/region.py | 3 ++- colossalai/autochunk/autochunk_codegen.py | 2 +- colossalai/fx/codegen/activation_checkpoint_codegen.py | 2 +- examples/language/gpt/hybridparallelism/data.py | 2 ++ 6 files changed, 8 insertions(+), 5 deletions(-) diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py index cd244b22cac0..68a27d91986b 100644 --- a/colossalai/_analyzer/fx/codegen.py +++ b/colossalai/_analyzer/fx/codegen.py @@ -246,7 +246,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, @compatibility(is_backward_compatible=True) class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index f5e8e31f5e97..60de7743a52e 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -5,7 +5,7 @@ import torch.nn as nn from colossalai.utils import _cast_float -from colossalai.zero.legacy.gemini.tensor_utils import free_storage +from colossalai.utils.common import free_storage from .region_manager import RegionManager from .util import GlobalRuntimeInfo diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index ea92c714ce31..a9f6f4c180de 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -3,7 +3,8 @@ import torch from torch.fx import Node -from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage +from colossalai.utils.common import free_storage +from colossalai.zero.gemini.chunk.chunk import alloc_storage class Region: diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 9571fa2c17f0..07dbf8a79fb6 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -372,7 +372,7 @@ def __init__( if print_progress: get_logger().info("AutoChunk start codegen") - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} diff 
--git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index dfb5754d71c1..28451bdd1870 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -625,7 +625,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if CODEGEN_AVAILABLE: class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: + def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] globals_: Dict[str, Any] = {} diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py index ef51f938dc4f..e5dc882bc097 100644 --- a/examples/language/gpt/hybridparallelism/data.py +++ b/examples/language/gpt/hybridparallelism/data.py @@ -62,6 +62,8 @@ def __init__( self.text_fields = self.task_text_field_map[task_name] self.num_labels = self.glue_task_num_labels[task_name] self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + if not getattr(self.tokenizer, "pad_token", None): + self.tokenizer.pad_token = self.tokenizer._eos_token self.setup() def setup(self): From cbea063b0efecf18ce8b464e81117bcc7088f371 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 22 Apr 2024 11:25:39 +0800 Subject: [PATCH 35/44] [shardformer] fix pipeline grad ckpt (#5620) * [shardformer] fix pipeline grad ckpt --- colossalai/shardformer/shard/grad_ckpt_config.py | 6 ++++++ colossalai/shardformer/shard/shard_config.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index 9c6c2b54ea39..9fc857d19dbc 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. """ + """ Args: gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. 
@@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): num_stages: Optional[int] = None num_model_chunks: Optional[int] = None num_model_layers: Optional[int] = None + num_layers_per_stage: Optional[List[int]] = None num_ckpt_layers_per_stage: Optional[List[int]] = None def __post_init__(self): @@ -70,6 +72,10 @@ def __post_init__(self): def _enable_gradient_checkpointing_ratio(self) -> bool: return self.gradient_checkpointing_ratio is not None + @property + def _customize_num_layers_per_stage(self) -> bool: + return self.num_layers_per_stage is not None and self.num_model_layers is not None + @property def _enable_customized_ckpt_layers_per_stage(self) -> bool: return self.num_ckpt_layers_per_stage is not None diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 963732543f27..597dd9c26354 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -7,7 +7,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager -from .grad_ckpt_config import GradientCheckpointConfig +from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig __all__ = ["ShardConfig"] SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @@ -30,6 +30,7 @@ class ShardConfig: gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ + tensor_parallel_process_group: Optional[ProcessGroup] = None sequence_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None @@ -104,6 +105,16 @@ def __post_init__(self): else: self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) + if ( + self.pipeline_stage_manager is not None + and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig) + and self.gradient_checkpoint_config._customize_num_layers_per_stage + ): + self.pipeline_stage_manager.set_distribution_config( + self.gradient_checkpoint_config.num_model_layers, + self.gradient_checkpoint_config.num_layers_per_stage, + ) + def _turn_on_all_optimization(self): """ Turn on all optimization. 
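To make the relationship between the two checkpointing knobs above concrete, here is a minimal sketch with assumed semantics; the helper is illustrative and not part of ColossalAI, though it mirrors the documented rule that an explicit per-stage list takes precedence over the global ratio.

from typing import List, Optional

def ckpt_layers_per_stage(
    num_layers_per_stage: List[int],
    gradient_checkpointing_ratio: Optional[float] = None,
    num_ckpt_layers_per_stage: Optional[List[int]] = None,
) -> List[int]:
    # Hypothetical helper: decide how many layers each pipeline stage recomputes.
    if num_ckpt_layers_per_stage is not None:
        return num_ckpt_layers_per_stage
    if gradient_checkpointing_ratio is not None:
        return [round(n * gradient_checkpointing_ratio) for n in num_layers_per_stage]
    return [0] * len(num_layers_per_stage)

# 8 layers placed as [3, 3, 2] over 3 stages with a 0.5 ratio -> [2, 2, 1]
assert ckpt_layers_per_stage([3, 3, 2], gradient_checkpointing_ratio=0.5) == [2, 2, 1]

The new num_layers_per_stage field, in turn, is what ShardConfig.__post_init__ now forwards to PipelineStageManager.set_distribution_config, so layer placement and checkpoint assignment are derived from the same distribution.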
From 46190f498926b93753292fd03ba3e75ff62464b3 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 23 Apr 2024 19:13:32 +0800 Subject: [PATCH 36/44] [shardformer] fix whisper (#5628) --- colossalai/shardformer/modeling/whisper.py | 61 ++++++------------- tests/kit/model_zoo/transformers/whisper.py | 1 + .../test_model/test_shard_whisper.py | 2 +- 3 files changed, 21 insertions(+), 43 deletions(-) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index ae1772a66848..6d7df963a3a0 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -39,6 +39,8 @@ def _get_attention_mask( hidden_states: torch.Tensor, past_key_values_length: int, attention_mask: Optional[torch.FloatTensor], + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ): batch_size, seq_length = hidden_states.shape[:2] mask_seq_length = past_key_values_length + seq_length @@ -51,12 +53,20 @@ def _get_attention_mask( is_causal=True, ) else: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - hidden_states, - past_key_values_length, - ) + input_shape = (batch_size, seq_length) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and head_mask is None and not output_attentions: + # output_attentions=True & head_mask can not be supported when using SDPA. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, input_shape, hidden_states, past_key_values_length + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length + ) return attention_mask @@ -700,33 +710,9 @@ def whisper_decoder_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + attention_mask = _get_attention_mask( + self, shard_config, inputs_embeds, past_key_values_length, attention_mask + ) # embed positions if input_ids is not None: @@ -734,14 +720,6 @@ def whisper_decoder_forward( else: positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) - attention_mask = _get_attention_mask( - self, - shard_config, - inputs_embeds, - past_key_values_length, - attention_mask, - ) - hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -758,7 +736,6 @@ def whisper_decoder_forward( "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." ) input_shape = hidden_states.size()[:-1] - attention_mask = _get_attention_mask( self, shard_config, diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index d69bebe6cc04..0d9a581dfbe9 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -66,6 +66,7 @@ def data_gen_for_audio_classification(): encoder_ffn_dim=1536, encoder_layers=2, vocab_size=51866, + _attn_implementation="eager", ) # register the Whisper variants diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 6efb8a922f85..af61e464014f 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 2, "enable_metadata_cache": False, "enable_all_optimization": True, - "use_lazy_init": True, + "use_lazy_init": False, "precision": "fp32", "initial_scale": 1, }, From 4a0b2de5d2ba6a2b75282ca3a4dc919777c38893 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 23 Apr 2024 23:25:11 +0800 Subject: [PATCH 37/44] [test] fix llama model test --- tests/kit/model_zoo/transformers/llama.py | 1 + tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 58b5b0487a82..08c05e9063bf 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -65,6 +65,7 @@ def data_gen_for_casual_lm(): num_attention_heads=4, max_position_embeddings=128, num_labels=16, + attn_implementation="eager", ) if hasattr(config, "pad_token_id"): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 27f904292597..2a10d86c79bb 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -32,7 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, model_fn, loss_fn, test_config ) if enable_gradient_checkpointing: - org_model.gradient_checkpointing_enable() + # org_model.gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable() org_loss, org_output, sharded_loss, sharded_output = 
run_forward_backward_with_hybrid_plugin( From 2e2d1c1fc91aeaacd83998f447fa133773694bee Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:24:01 +0800 Subject: [PATCH 38/44] fix the opt upgrade (#5634) --- colossalai/shardformer/modeling/opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index dd3afecd6723..76534b5d5d2e 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -43,7 +43,7 @@ def _get_attention_mask( is_causal=True, ) else: - attention_mask = self.decoder._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), hidden_states, From e021cea20cc104a6481c61655bc717fb6d4daaab Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 24 Apr 2024 14:53:21 +0800 Subject: [PATCH 39/44] [shardformer] fix attn replacement (#5636) --- colossalai/shardformer/policies/falcon.py | 20 +++++------- colossalai/shardformer/policies/sam.py | 34 +++++++++++---------- colossalai/shardformer/policies/whisper.py | 16 ++++++++++ tests/kit/model_zoo/transformers/whisper.py | 1 - 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 628e9fdc0d96..09d895843b61 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -7,12 +7,7 @@ import colossalai.shardformer.layer as col_nn -from ..modeling.falcon import ( - FalconPipelineForwards, - build_falcon_alibi_tensor_fn, - get_falcon_flash_attention_forward, - get_tp_falcon_decoder_layer_forward, -) +from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["FalconPolicy"] @@ -30,7 +25,7 @@ def preprocess(self): return self.model def module_policy(self): - from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel + from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel if not self.model.config.new_decoder_architecture and self.model.config.multi_query: warnings.warn( @@ -141,11 +136,12 @@ def module_policy(self): ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={"forward": get_falcon_flash_attention_forward()}, - policy=policy, - target_key=FalconAttention, - ) + warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.") + # self.append_or_create_method_replacement( + # description={"forward": get_falcon_flash_attention_forward()}, + # policy=policy, + # target_key=FalconAttention, + # ) return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 498e62164b09..ce33925ff82e 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -1,6 +1,8 @@ +import warnings + import colossalai.shardformer.layer as col_nn -from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward +from ..modeling.sam import forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["SamPolicy", "SamModelPolicy"] @@ -15,7 
+17,6 @@ def preprocess(self): def module_policy(self): from transformers.models.sam.modeling_sam import ( - SamAttention, SamTwoWayAttentionBlock, SamTwoWayTransformer, SamVisionAttention, @@ -210,20 +211,21 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement( - description={ - "forward": get_sam_flash_attention_forward(), - }, - policy=policy, - target_key=SamAttention, - ) - self.append_or_create_method_replacement( - description={ - "forward": get_sam_vision_flash_attention_forward(), - }, - policy=policy, - target_key=SamVisionAttention, - ) + warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.") + # self.append_or_create_method_replacement( + # description={ + # "forward": get_sam_flash_attention_forward(), + # }, + # policy=policy, + # target_key=SamAttention, + # ) + # self.append_or_create_method_replacement( + # description={ + # "forward": get_sam_vision_flash_attention_forward(), + # }, + # policy=policy, + # target_key=SamVisionAttention, + # ) return policy diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 16ed2607c6f7..aeb6687971e5 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -48,6 +48,8 @@ def module_policy(self): WhisperDecoderLayer, WhisperEncoder, WhisperEncoderLayer, + WhisperFlashAttention2, + WhisperSdpaAttention, ) policy = {} @@ -242,6 +244,20 @@ def module_policy(self): policy=policy, target_key=WhisperAttention, ) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperFlashAttention2, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperSdpaAttention, + ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( description={ diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index 0d9a581dfbe9..d69bebe6cc04 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -66,7 +66,6 @@ def data_gen_for_audio_classification(): encoder_ffn_dim=1536, encoder_layers=2, vocab_size=51866, - _attn_implementation="eager", ) # register the Whisper variants From fa0d8ab297326c4c725af29b944c892ad290bb74 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 14:53:51 +0800 Subject: [PATCH 40/44] [shardformer] update flashattention replacement (#5637) * update transformers update transformers fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/shardformer/policies/gpt2.py | 9 ++++++- colossalai/shardformer/policies/gptj.py | 9 ++++++- colossalai/shardformer/policies/llama.py | 24 +++++++++++++++---- colossalai/shardformer/policies/mistral.py | 23 ++++++++++++------ colossalai/shardformer/policies/opt.py | 14 ++++++++--- tests/kit/model_zoo/transformers/llama.py | 1 - .../test_model/test_shard_mistral.py | 2 +- 7 files changed, 63 insertions(+), 19 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 98db7b948954..6f4f835a8dbe 100644 --- 
a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -35,13 +35,20 @@ def preprocess(self): Reshape the Embedding layer to make the embedding dimension divisible by world_size """ self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model + ATTN_IMPLEMENTATION = { + "eager": GPT2Attention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -186,7 +193,7 @@ def module_policy(self): "forward": get_gpt2_flash_attention_forward(), }, policy=policy, - target_key=GPT2Attention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: policy[GPT2Model].method_replacement = { diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 4b69137a6892..1280efaec921 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -30,13 +30,20 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel + ATTN_IMPLEMENTATION = { + "eager": GPTJAttention, + } + policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -160,7 +167,7 @@ def module_policy(self): "forward": get_gptj_flash_attention_forward(), }, policy=policy, - target_key=GPTJAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ff686a179553..1b30ae9c9f40 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -36,13 +36,27 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + LlamaModel, + LlamaSdpaAttention, + ) + ATTN_IMPLEMENTATION = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -93,7 +107,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) elif sp_mode == "all_to_all": decoder_attribute_replacement = { @@ -102,7 +116,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if getattr(self.model.config, 
"num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size - policy[LlamaAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) self.append_or_create_method_replacement( @@ -110,7 +124,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) self.append_or_create_method_replacement( description={ @@ -221,7 +235,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, - target_key=LlamaAttention, + target_key=attn_cls, ) if self.pipeline_stage_manager is None: # replace llama model forward method diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 61e1b5f9c7b4..b3f89b4042c1 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -26,13 +26,26 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel + from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralFlashAttention2, + MistralModel, + ) + + ATTN_IMPLEMENTATION = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -128,10 +141,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_mistral_flash_attention_forward(), + "forward": get_mistral_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=MistralAttention, + target_key=attn_cls, ) return policy @@ -143,10 +156,6 @@ def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) method_replacement = {"forward": partial(new_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = {"forward": partial(new_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) - class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 2bb28b095114..2f6eabd5fef9 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -44,13 +44,21 @@ def config_sanity_check(self): def preprocess(self): self.tie_weight = self.tie_weight_check() + self.origin_attn_implement = self.model.config._attn_implementation return self.model def module_policy(self): - from 
transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2 + + ATTN_IMPLEMENTATION = { + "eager": OPTAttention, + "flash_attention_2": OptFlashAttention2, + } policy = {} + attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -81,7 +89,7 @@ def module_policy(self): ] ) - policy[OPTAttention] = ModulePolicyDescription( + policy[attn_cls] = ModulePolicyDescription( attribute_replacement={ "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, @@ -151,7 +159,7 @@ def module_policy(self): "forward": get_opt_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=OPTAttention, + target_key=attn_cls, ) if not self.shard_config.pipeline_stage_manager: self.append_or_create_method_replacement( diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 08c05e9063bf..58b5b0487a82 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -65,7 +65,6 @@ def data_gen_for_casual_lm(): num_attention_heads=4, max_position_embeddings=128, num_labels=16, - attn_implementation="eager", ) if hasattr(config, "pad_token_id"): diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 07bc91b33b72..f127472aee0b 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -156,7 +156,7 @@ def check_mistral(rank, world_size, port): run_mistral_test() -@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.") +@pytest.mark.skip("something wrong with pipeline parallelism") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From 52f4d3a9c113c9532effd833d89eb3e5c295810f Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 24 Apr 2024 15:02:43 +0800 Subject: [PATCH 41/44] [test] fix llama test (#5638) --- tests/kit/model_zoo/transformers/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 58b5b0487a82..61fa560506c2 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -64,7 +64,6 @@ def data_gen_for_casual_lm(): intermediate_size=64, num_attention_heads=4, max_position_embeddings=128, - num_labels=16, ) if hasattr(config, "pad_token_id"): From fcceb78a7a7e2ee1c4bff966b034120df06d6e16 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 24 Apr 2024 16:06:27 +0800 Subject: [PATCH 42/44] [gemini] fix buffer cast (#5639) --- colossalai/zero/gemini/gemini_ddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c79422171f1b..b25de1d68613 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -840,6 +840,7 @@ def _cast_buffers(self): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() + for buffer in self.module.buffers(): buffer.data = buffer.to(get_accelerator().get_current_device()) if 
torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) From 2ad14bd1f3a9d8e3acaecc9acac31974abf4ec01 Mon Sep 17 00:00:00 2001 From: Wang Binluo <32676639+wangbluo@users.noreply.github.com> Date: Wed, 24 Apr 2024 17:12:57 +0800 Subject: [PATCH 43/44] Fix shardformer upgrade (#5640) * fix llama model * fix the mistral * fix the shardformer model * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/shardformer/modeling/falcon.py | 141 --------------------- colossalai/shardformer/modeling/llama.py | 5 - colossalai/shardformer/modeling/mistral.py | 5 +- colossalai/shardformer/modeling/opt.py | 14 -- colossalai/shardformer/policies/falcon.py | 5 - 5 files changed, 1 insertion(+), 169 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 34754ecdbac9..df3b09c71cbc 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -6,7 +6,6 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.nn import functional as F from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_causal_attention_mask, @@ -25,7 +24,6 @@ FalconForSequenceClassification, FalconForTokenClassification, FalconModel, - apply_rotary_pos_emb, build_alibi_tensor, ) from transformers.utils import logging @@ -171,145 +169,6 @@ def forward( return forward -def get_falcon_flash_attention_forward(): - try: - pass - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") - from transformers.models.falcon.modeling_falcon import FalconAttention - - def forward( - self: FalconAttention, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - kv_seq_len += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size, self.num_heads, kv_length, head_dim] - # - value: [batch_size, self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=-2) - value_layer = torch.cat((past_value, value_layer), dim=-2) - - kv_length = key_layer.shape[-2] - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - if alibi is None: - if self._use_sdpa and not output_attentions: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attention_scores = None - else: - attention_scores = query_layer @ key_layer.transpose(-1, -2) - attention_scores /= math.sqrt(self.head_dim) - - attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) - # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). 
- attn_output = attention_scores @ value_layer - - attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_scores - else: - return attn_output, present - else: - if self._use_sdpa and not output_attentions and head_mask is None: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - else: - matmul_result = query_layer @ key_layer.transpose(-1, -2) - - # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) - - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - - # matmul: [batch_size * num_heads, q_length, head_dim] - attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) - - # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_probs - else: - return attn_output, present - - return forward - - class FalconPipelineForwards: """ This class serves as a micro library for falcon pipeline forwards. 
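For reference, the deleted Falcon forward paired a fused torch.nn.functional.scaled_dot_product_attention path with a manual softmax fallback (used when attention weights must be returned); the two are meant to compute the same result. A minimal, self-contained sketch of that equivalence, with illustrative shapes, not part of the patch:

    import math
    import torch
    import torch.nn.functional as F

    bsz, n_heads, q_len, head_dim = 2, 4, 16, 32
    q = torch.randn(bsz, n_heads, q_len, head_dim)
    k = torch.randn(bsz, n_heads, q_len, head_dim)
    v = torch.randn(bsz, n_heads, q_len, head_dim)

    # fused path (what the removed code called when alibi is None)
    out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # manual reference path: softmax(QK^T / sqrt(d) + causal mask) @ V
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
    causal = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
    out_ref = F.softmax(scores + causal, dim=-1) @ v

    assert torch.allclose(out_sdpa, out_ref, atol=1e-5)

Dropping the duplicated custom paths in favour of the stock transformers implementation removes this maintenance burden, at the cost of a model-specific flash path.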
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ac9baad5fdb9..0eb08a0432e7 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -33,7 +33,6 @@ from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d -from ..layer._operation import gather_forward_split_backward class LlamaPipelineForwards: @@ -115,10 +114,6 @@ def llama_model_forward( ) position_ids = position_ids.unsqueeze(0) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 3b876bcab96a..06b8f93d20b5 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -215,9 +215,6 @@ def forward( attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, past_key_value return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 76534b5d5d2e..8f841c8a6615 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -58,20 +58,6 @@ class OPTPipelineForwards: under pipeline setting. """ - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - @staticmethod def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 09d895843b61..b82840576974 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -137,11 +137,6 @@ def module_policy(self): if self.shard_config.enable_flash_attention: warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.") - # self.append_or_create_method_replacement( - # description={"forward": get_falcon_flash_attention_forward()}, - # policy=policy, - # target_key=FalconAttention, - # ) return policy def postprocess(self): From c253a7e17ae3c012feed234178c34525b7ba5f96 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 24 Apr 2024 22:05:54 +0800 Subject: [PATCH 44/44] [shardformer]support pipeline parallelism for mistral. 
(#5642) * [shardformer] fix attn replacement (#5636) * [shardformer] update flashattention replacement (#5637) * update transformers update transformers fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Support LLaMA-3 CPT and ST (#5619) * support LLaMA-3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [exampe] update llama example (#5626) * [plugin] support dp inside for hybriad parallel * [example] update llama benchmark * [example] update llama benchmark * [example] update llama readme * [example] update llama readme * [example] llama3 (#5631) * release llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [release] llama3 * [test] fix llama test (#5638) * [gemini] fix buffer cast (#5639) * support pp for mistral * fix * fix fix fix * fix --------- Co-authored-by: Hongxin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tong Li Co-authored-by: binmakeswell --- colossalai/shardformer/modeling/mistral.py | 481 ++++++++++++++++-- colossalai/shardformer/policies/falcon.py | 1 + colossalai/shardformer/policies/gpt2.py | 1 + colossalai/shardformer/policies/gptj.py | 1 + colossalai/shardformer/policies/mistral.py | 152 +++++- colossalai/shardformer/policies/opt.py | 1 + requirements/requirements.txt | 2 +- tests/kit/model_zoo/transformers/mistral.py | 3 + .../test_model/test_shard_mistral.py | 21 +- 9 files changed, 604 insertions(+), 59 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 06b8f93d20b5..ac7845400d8d 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -2,12 +2,22 @@ from typing import List, Optional, Tuple, Union import torch -from transformers.cache_utils import Cache +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.mistral.modeling_mistral import MistralModel +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel from transformers.utils import logging +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.shard import ShardConfig + +from ..layer import ColoAttention + logger = logging.get_logger(__name__) @@ -24,6 +34,10 @@ def mistral_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ) -> Union[Tuple, BaseModelOutputWithPast]: if use_cache: logger.warning_once("use_cache=True is not supported for Mistral models at the moment.") @@ -35,6 +49,376 @@ def mistral_model_forward( return_dict = return_dict if 
return_dict is not None else self.config.use_return_dict
+        # retrieve input_ids and inputs_embeds
+        if stage_manager.is_first_stage():
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+            elif inputs_embeds is not None:
+                batch_size, seq_length, _ = inputs_embeds.shape
+            else:
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
+            inputs_embeds = self.embed_tokens(input_ids)
+            hidden_states = inputs_embeds
+        else:
+            input_shape = hidden_states.shape[:-1]
+            batch_size, seq_length = input_shape
+            device = hidden_states.device
+
+        past_key_values_length = 0
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                )
+
+        if shard_config.enable_flash_attention:
+            # in this case, attention_mask is a dict rather than a tensor
+            mask_shape = (batch_size, 1, seq_length, seq_length)
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape,
+                hidden_states.dtype,
+                hidden_states.device,
+                q_padding_mask=attention_mask,
+                is_causal=True,
+            )
+        else:
+            if self._use_flash_attention_2:
+                # 2d mask is passed through the layers
+                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            else:
+                # 4d mask is passed through the layers
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    hidden_states,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + ) + assert num_ckpt_layers <= end_idx - start_idx + + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if idx - start_idx < num_ckpt_layers: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + else: + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_causal_lm_forward( + self: MistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def mistral_for_sequence_classification_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = MistralForwards.mistral_model_forward( + self.model, + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + batch_size = hidden_states.shape[0] + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." 
+ + def forward( + self: MistralModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -47,6 +431,12 @@ def mistral_model_forward( past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -67,19 +457,29 @@ def mistral_model_forward( " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length, seq_length) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, ) + else: + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) hidden_states = inputs_embeds @@ -93,6 +493,7 @@ def mistral_model_forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -121,7 +522,7 @@ def mistral_model_forward( hidden_states = layer_outputs[0] if use_cache: - layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -133,6 +534,8 @@ def mistral_model_forward( 
all_hidden_states += (hidden_states,) next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -143,11 +546,11 @@ def mistral_model_forward( attentions=all_self_attns, ) + return forward -def get_mistral_flash_attention_forward(): - from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention +def get_mistral_flash_attention_forward(shard_config: ShardConfig): + from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv def forward( self: MistralAttention, @@ -164,15 +567,14 @@ def forward( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() - assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = ( - self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - ) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -190,31 +592,18 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." 
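+        # attention_mask here is the kwargs dict produced by
+        # ColoAttention.prepare_attn_kwargs in the model forward, so it is
+        # unpacked directly into the fused attention call below.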
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attn_mask_type = AttnMaskType.paddedcausal - - attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type - ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - return attn_output, past_key_value + return attn_output, None, past_key_value return forward diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index b82840576974..e72a97e4bfc0 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -137,6 +137,7 @@ def module_policy(self): if self.shard_config.enable_flash_attention: warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.") + return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 36d081c17991..6f4f835a8dbe 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -48,6 +48,7 @@ def module_policy(self): policy = {} attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index f40a17ad9487..1280efaec921 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -43,6 +43,7 @@ def module_policy(self): policy = {} attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index b4d25a2fe85e..b5018e47d65d 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -1,8 +1,10 @@ import warnings from functools import partial -from typing import Callable, Dict, Union +from typing import Callable, Dict, List, Union import torch.nn as nn +from torch import Tensor +from torch.nn import Module from colossalai.shardformer.layer import ( FusedRMSNorm, @@ -14,7 +16,11 @@ VocabParallelLMHead1D, ) -from ..modeling.mistral import MistralForwards, get_mistral_flash_attention_forward +from ..modeling.mistral import ( + MistralForwards, + get_mistral_flash_attention_forward, + get_mistral_model_forward_for_flash_attn, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"] @@ -45,6 +51,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} attn_cls = 
ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -145,16 +152,83 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=attn_cls, ) + if self.pipeline_stage_manager is None: + # replace llama model forward method + self.append_or_create_method_replacement( + description={ + "forward": get_mistral_model_forward_for_flash_attn(self.shard_config), + }, + policy=policy, + target_key=MistralModel, + ) return policy def postprocess(self): return self.model - def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - method_replacement = {"forward": partial(new_forward)} + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MistralModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MistralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + return held_layers + class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: @@ -164,17 +238,28 @@ def module_policy(self): policy = super().module_policy() from transformers.models.mistral.modeling_mistral import MistralModel - self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy) + if 
@@ -164,17 +238,28 @@ def module_policy(self):
         policy = super().module_policy()
         from transformers.models.mistral.modeling_mistral import MistralModel
 
-        self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy)
+        if self.pipeline_stage_manager:
+            self.set_pipeline_forward(
+                model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy
+            )
+
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        held_layers = super().get_held_layers()
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mistral model"""
+        return []
+
 
 class MistralForCausalLMPolicy(MistralPolicy):
     def module_policy(self):
         from transformers import MistralForCausalLM
 
         policy = super().module_policy()
 
-        if self.pipeline_stage_manager:
-            warnings.warn("Mistral doesn't support pipeline parallelism now.")
-
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
@@ -207,8 +292,38 @@ def module_policy(self):
 
             policy.update(new_item)
 
+        if self.pipeline_stage_manager:
+            # replace forward under pipeline parallel setting
+            self.set_pipeline_forward(
+                model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy
+            )
+
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        mistral_model = self.model.model
+        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            if (
+                id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
+                # tie weights
+                return [
+                    {
+                        0: mistral_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
+        return []
+
 
 class MistralForSequenceClassificationPolicy(MistralPolicy):
     def module_policy(self):
@@ -227,9 +342,26 @@ def module_policy(self):
                     ]
                 )
             }
+            policy.update(new_item)
 
-        if self.pipeline_stage_manager:
-            warnings.warn("Mistral doesn't support pipeline parallelism now.")
+        if self.pipeline_stage_manager:
+            # replace forward under pipeline parallel setting
+            self.set_pipeline_forward(
+                model_cls=MistralForSequenceClassification,
+                new_forward=MistralForwards.mistral_for_sequence_classification_forward,
+                policy=policy,
+            )
 
-        policy.update(new_item)
         return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in mistral for sequence classification model"""
+        return []
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index e589c674088f..2f6eabd5fef9 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -58,6 +58,7 @@ def module_policy(self):
         policy = {}
 
         attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+        embedding_cls = None
 
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index b0352230788a..d307312ded8e 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -16,4 +16,4 @@ ray
 sentencepiece
 google
 protobuf
-transformers==4.36.0
+transformers==4.36.2
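Note: the get_shared_params override added to MistralForCausalLMPolicy above reports the input embedding and the LM head as a single shared parameter when their weights are tied, so the first and last pipeline stages can keep the shared tensor's gradients in sync. A small standalone sketch of the identity check it performs (the modules here are hypothetical stand-ins, not the policy's actual model):

import torch.nn as nn

embed_tokens = nn.Embedding(50258, 256)
lm_head = nn.Linear(256, 50258, bias=False)
lm_head.weight = embed_tokens.weight  # tie: both modules now share one tensor

num_stages = 4
if id(embed_tokens.weight) == id(lm_head.weight):
    # Stage 0 holds the embedding, the last stage holds the LM head;
    # both entries reference the same storage and must be synchronized.
    shared_params = [{0: embed_tokens.weight, num_stages - 1: lm_head.weight}]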
diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py
index 37f87585759e..ae5a9700240a 100644
--- a/tests/kit/model_zoo/transformers/mistral.py
+++ b/tests/kit/model_zoo/transformers/mistral.py
@@ -52,6 +52,9 @@ def data_gen_for_sequence_classification():
     hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
 )
 
+if hasattr(config, "pad_token_id"):
+    config.pad_token_id = config.eos_token_id
+
 model_zoo.register(
     name="transformers_mistral",
     model_fn=lambda: transformers.MistralModel(config),
diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py
index f127472aee0b..05c1998146b6 100644
--- a/tests/test_shardformer/test_model/test_shard_mistral.py
+++ b/tests/test_shardformer/test_model/test_shard_mistral.py
@@ -91,7 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config["precision"] == "fp32":
-            atol, rtol = 1e-4, 1e-3
+            atol, rtol = 2e-4, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
         check_weight(
@@ -114,6 +114,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 4,
             "pp_size": 1,
@@ -156,7 +174,6 @@ def check_mistral(rank, world_size, port):
     run_mistral_test()
 
 
-@pytest.mark.skip("something wrong with pipeline parallelism")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()