diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index c9650ce4f712..f05b0d1180fb 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -105,7 +105,20 @@ We will follow this roadmap to develop Shardformer:
   - [ ] Whisper
 - [ ] Multi-modal
   - [ ] To be added
-
+- [ ] Flash Attention Support
+  - [ ] NLP
+    - [x] BERT
+    - [x] T5
+    - [x] LLaMA
+    - [x] GPT2
+    - [x] OPT
+    - [x] BLOOM
+    - [ ] GLM
+    - [ ] RoBERTa
+    - [ ] ALBERT
+    - [ ] ERNIE
+    - [ ] GPT Neo
+    - [ ] GPT-J

 ## 💡 API Design

 We will discuss the major components of `ShardFormer` below to help you better understand how things work.
@@ -396,7 +409,7 @@ In the case of using 2 GPUs, the training times are as follows.

-In the case of using 4 GPUs, the training times are as follows.
+In the case of using 4 GPUs, the training times are as follows.

 | N_CTX | org_model | shard_model |
 | :------: | :-----: | :-----: |
@@ -419,7 +432,7 @@ As shown in the figures above, when the sequence length is around 1000 or greate

 ### Convergence

-To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
+To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results.

 | accuracy | f1 | loss | GPU number | model shard |
 | :------: | :-----: | :-----: | :--------: | :---------: |
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index faa579529ce4..b69ac63d6a8c 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1,22 +1,22 @@
-from typing import Optional, Tuple
 import math
+from typing import Optional, Tuple

 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import functional as F

-__all__ = ['get_bert_forward']


-def get_bert_forward():
+def get_bert_flash_attention_forward():
     try:
         from xformers.ops import memory_efficient_attention as me_attention
     except:
         raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
-    def bert_flash_attention_forward(
-        self,
+    from transformers.models.bert.modeling_bert import BertAttention
+
+    def forward(
+        self: BertAttention,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -67,16 +67,14 @@ def bert_flash_attention_forward(
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             query_length, key_length = query_layer.shape[2], key_layer.shape[2]
             if use_cache:
-                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-                    -1, 1
-                )
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
             else:
                 position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
             position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r

             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
-            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)    # fp16 compatibility

             if self.position_embedding_type == "relative_key":
                 relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
@@ -99,8 +97,13 @@ def bert_flash_attention_forward(
         query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
         key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
         value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
-
-        context_layer = me_attention(query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale)
+
+        context_layer = me_attention(query_layer,
+                                     key_layer,
+                                     value_layer,
+                                     attn_bias=final_attention_mask,
+                                     p=self.dropout.p,
+                                     scale=scale)
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
@@ -109,5 +112,31 @@ def bert_flash_attention_forward(
         if self.is_decoder:
             outputs = outputs + (past_key_value,)
         return outputs
-
-    return bert_flash_attention_forward
\ No newline at end of file
+
+    return forward
+
+
+def get_jit_fused_bert_self_output_forward():
+
+    from transformers.models.bert.modeling_bert import BertSelfOutput
+
+    def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+    return forward
+
+
+def get_jit_fused_bert_output_forward():
+
+    from transformers.models.bert.modeling_bert import BertOutput
+
+    def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+    return forward
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 4b6458ae0627..a586d39ec837 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -71,16 +71,17 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
     return build_bloom_alibi_tensor


-def get_bloom_forward():
+
+def get_bloom_flash_attention_forward(enable_jit_fused=False):
     try:
         from xformers.ops import memory_efficient_attention as me_attention
     except:
         raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-    from transformers.models.bloom.modeling_bloom import dropout_add
+    from transformers.models.bloom.modeling_bloom import BloomAttention

-    def bloom_flash_attention_forward(
-        self,
+    def forward(
+        self: BloomAttention,
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         alibi: torch.Tensor,
@@ -88,7 +89,8 @@ def bloom_flash_attention_forward(
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
-        output_attentions: bool = False,):
+        output_attentions: bool = False,
+    ):
         fused_qkv = self.query_key_value(hidden_states)
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
@@ -114,30 +116,177 @@ def bloom_flash_attention_forward(
         if use_cache is True:
             present = (key_layer, value_layer)
         else:
             present = None
-
+
         tgt_len = key_layer.size()[1]
-        attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length), dtype=torch.float32, device=query_layer.device, requires_grad=True)
-        attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
-        attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min)
-
-        context_layer = me_attention(query_layer, key_layer, value_layer, attn_bias=attention_numerical_mask, scale=self.inv_norm_factor, p=self.attention_dropout.p)
+        attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
+                                               dtype=torch.float32,
+                                               device=query_layer.device,
+                                               requires_grad=True)
+        attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
+                                                                         kv_length) * self.beta
+        attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
+                                                     torch.finfo(torch.float32).min)
+
+        context_layer = me_attention(query_layer,
+                                     key_layer,
+                                     value_layer,
+                                     attn_bias=attention_numerical_mask,
+                                     scale=self.inv_norm_factor,
+                                     p=self.attention_dropout.p)
         context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)

         if self.pretraining_tp > 1 and self.slow_but_exact:
             slices = self.hidden_size / self.pretraining_tp
             output_tensor = torch.zeros_like(context_layer)
             for i in range(self.pretraining_tp):
                 output_tensor = output_tensor + F.linear(
-                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
-                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+                    context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+                    self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
                 )
         else:
             output_tensor = self.dense(context_layer)

         # TODO to replace with the bias_dropout_add function in jit
-        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
         outputs = (output_tensor, present, None)

         return outputs
-
-    return bloom_flash_attention_forward
\ No newline at end of file
+
+    return forward
+
+
+def get_jit_fused_bloom_attention_forward():
+
+    from transformers.models.bloom.modeling_bloom import BloomAttention
+
+    def forward(
+        self: BloomAttention,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ):
+        fused_qkv = self.query_key_value(hidden_states)    # [batch_size, seq_length, 3 x hidden_size]
+
+        # 3 x [batch_size, seq_length, num_heads, head_dim]
+        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+        batch_size, q_length, _, _ = query_layer.shape
+
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along seq_length dimension:
+            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
+            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=2)
+            value_layer = torch.cat((past_value, value_layer), dim=1)
+
+        _, _, kv_length = key_layer.shape
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
+        # [batch_size * num_heads, q_length, kv_length]
+        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+        matmul_result = alibi.baddbmm(
+            batch1=query_layer,
+            batch2=key_layer,
+            beta=self.beta,
+            alpha=self.inv_norm_factor,
+        )
+
+        # change view to [batch_size, num_heads, q_length, kv_length]
+        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
+
+        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+        input_dtype = attention_scores.dtype
+        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+        if input_dtype == torch.float16:
+            attention_scores = attention_scores.to(torch.float)
+        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+
+        # [batch_size, num_heads, q_length, kv_length]
+        attention_probs = self.attention_dropout(attention_probs)
+
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        # change view [batch_size x num_heads, q_length, kv_length]
+        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+
+        # matmul: [batch_size * num_heads, q_length, head_dim]
+        context_layer = torch.bmm(attention_probs_reshaped, value_layer)
+
+        # change view [batch_size, num_heads, q_length, head_dim]
+        context_layer = self._merge_heads(context_layer)
+
+        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+        if self.pretraining_tp > 1 and self.slow_but_exact:
+            slices = self.hidden_size / self.pretraining_tp
+            output_tensor = torch.zeros_like(context_layer)
+            for i in range(self.pretraining_tp):
+                output_tensor = output_tensor + F.linear(
+                    context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+                    self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+                )
+        else:
+            output_tensor = self.dense(context_layer)
+
+        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+
+        outputs = (output_tensor, present)
+        if output_attentions:
+            outputs += (attention_probs,)
+
+        return outputs
+
+    return forward
+
+
+def get_jit_fused_bloom_mlp_forward():
+
+    from transformers.models.bloom.modeling_bloom import BloomMLP
+
+    def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+        if self.pretraining_tp > 1 and self.slow_but_exact:
+            intermediate_output = torch.zeros_like(residual)
+            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+            for i in range(self.pretraining_tp):
+                intermediate_output = intermediate_output + F.linear(
+                    hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
+                    self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
+                )
+        else:
+            intermediate_output = self.dense_4h_to_h(hidden_states)
+        output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+        return output
+
+    return forward
+
+
+def get_jit_fused_bloom_gelu_forward():
+
+    from transformers.models.bloom.modeling_bloom import BloomGelu
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
+        bias = torch.zeros_like(x)
+        if self.training:
+            return JitGeLUFunction.apply(x, bias)
+        else:
+            return self.bloom_gelu_forward(x, bias)
+
+    return forward
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index fbb1691ec595..ba789eb6458e 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -2,15 +2,15 @@

 import torch

-__all__ = ['get_gpt2_forward']

+def get_gpt2_flash_attention_forward():

-def get_gpt2_forward():
+    from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

     from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

-    def gpt2_flash_attention_forward(
-        self,
+    def forward(
+        self: GPT2Attention,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
@@ -77,7 +77,7 @@ def gpt2_flash_attention_forward(

         return outputs

-    return gpt2_flash_attention_forward
+    return forward


 def split_heads(tensor, num_heads, attn_head_size):
diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py
new file mode 100644
index 000000000000..6434348ef823
--- /dev/null
+++ b/colossalai/shardformer/modeling/jit.py
@@ -0,0 +1,34 @@
+import torch
+
+
+def get_dropout_add_func():
+
+    from transformers.models.bloom.modeling_bloom import dropout_add
+
+    def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+        return dropout_add(x, residual, prob, training)
+
+    return self_dropout_add
+
+
+def get_jit_fused_dropout_add_func():
+
+    from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
+
+    def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+        bias = torch.zeros_like(x)
+        if training:
+            return bias_dropout_add_fused_train(x, bias, residual, prob)
+        return bias_dropout_add_fused_inference(x, bias, residual, prob)
+
+    return self_dropout_add
+
+
+def get_jit_fused_gelu_forward_func():
+
+    from colossalai.kernel.jit.bias_gelu import bias_gelu
+
+    def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+        return bias_gelu(bias, x)
+
+    return bloom_gelu_forward
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index ff2e3e00d048..83aa01716832 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -2,33 +2,15 @@

 import torch

-__all__ = ['get_llama_forward']

+def get_llama_flash_attention_forward():

-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)    # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)    # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)    # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)    # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def get_llama_forward():
+    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

     from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

-    def llama_flash_attention_forward(
-        self,
+    def forward(
+        self: LlamaAttention,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -82,4 +64,4 @@ def llama_flash_attention_forward(

         return attn_output, None, past_key_value

-    return llama_flash_attention_forward
+    return forward
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index b317a1063034..36dd5788038b 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -1,16 +1,17 @@
 from typing import Optional, Tuple

 import torch
+from torch import nn

-__all__ = ['get_opt_forward']

+def get_opt_flash_attention_forward():

-def get_opt_forward():
+    from transformers.models.opt.modeling_opt import OPTAttention

     from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

-    def opt_flash_attention_forward(
-        self,
+    def forward(
+        self: OPTAttention,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -65,6 +66,7 @@ def opt_flash_attention_forward(
             if layer_head_mask.size() != (self.num_heads,):
                 raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                                  f" {layer_head_mask.size()}")
+
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
         if attention_mask != None:
@@ -87,4 +89,87 @@ def opt_flash_attention_forward(
         attn_output = self.out_proj(attn_output)

         return attn_output, None, past_key_value

-    return opt_flash_attention_forward
+    return forward
+
+
+def get_jit_fused_opt_decoder_layer_forward():
+
+    from transformers.models.opt.modeling_opt import OPTDecoderLayer
+
+    def forward(
+        self: OPTDecoderLayer,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+        if self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+
+        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+        # 350m applies layer norm AFTER attention
+        if not self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Fully Connected
+        hidden_states_shape = hidden_states.shape
+        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
+        residual = hidden_states
+
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+        if self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+
+        hidden_states = self.fc2(hidden_states)
+
+        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape)
+
+        # 350m applies layer norm AFTER attention
+        if not self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+    return forward
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index a23b24e6f571..14befc1f54e9 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -1,29 +1,30 @@
-from typing import Optional
+from typing import Optional, Tuple

 import torch

-__all__ = ['get_t5_forward']
+__all__ = ['get_t5_flash_attention_forward']


-def get_t5_forward():
+def get_t5_flash_attention_forward():
     try:
         from xformers.ops import memory_efficient_attention as me_attention
     except:
         raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
-    def t5_flash_attention_forward(
-        self,
-        hidden_states,
-        mask=None,
-        key_value_states=None,
-        position_bias=None,
-        past_key_value=None,
-        layer_head_mask=None,
-        query_length=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
+    from transformers.models.t5.modeling_t5 import T5Attention
+
+    def forward(
+        self: T5Attention,
+        hidden_states: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        key_value_states: Optional[torch.Tensor] = None,
+        position_bias: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        query_length: Optional[int] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
         """
@@ -128,4 +129,83 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):

         return outputs

-    return t5_flash_attention_forward
+    return forward
+
+
+def get_jit_fused_T5_layer_ff_forward():
+
+    from transformers.models.t5.modeling_t5 import T5LayerFF
+
+    def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor:
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training)
+        return hidden_states
+
+    return forward
+
+
+def get_T5_layer_self_attention_forward():
+
+    from transformers.models.t5.modeling_t5 import T5LayerSelfAttention
+
+    def forward(
+        self: T5LayerSelfAttention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_bias: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        normed_hidden_states = self.layer_norm(hidden_states)
+        attention_output = self.SelfAttention(
+            normed_hidden_states,
+            mask=attention_mask,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
+        outputs = (hidden_states,) + attention_output[1:]    # add attentions if we output them
+        return outputs
+
+    return forward
+
+
+def get_T5_layer_cross_attention_forward():
+
+    from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
+
+    def forward(
+        self: T5LayerCrossAttention,
+        hidden_states: torch.Tensor,
+        key_value_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_bias: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        query_length: Optional[int] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        normed_hidden_states = self.layer_norm(hidden_states)
+        attention_output = self.EncDecAttention(
+            normed_hidden_states,
+            mask=attention_mask,
+            key_value_states=key_value_states,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            query_length=query_length,
+            output_attentions=output_attentions,
+        )
+        layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
+        outputs = (layer_output,) + attention_output[1:]    # add attentions if we output them
+        return outputs
+
+    return forward
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index a213538f8809..e9f6a3f094cb 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -3,8 +3,13 @@
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
+from ..modeling.bert import (
+    get_bert_flash_attention_forward,
+    get_jit_fused_bert_output_forward,
+    get_jit_fused_bert_self_output_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-from ..modeling.bert import get_bert_forward

 __all__ = [
     'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
@@ -32,7 +37,13 @@ def preprocess(self):
         return self.model

     def module_policy(self):
-        from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertSelfAttention
+        from transformers.models.bert.modeling_bert import (
+            BertEmbeddings,
+            BertLayer,
+            BertOutput,
+            BertSelfAttention,
+            BertSelfOutput,
+        )

         policy = {}

@@ -121,11 +132,22 @@ def module_policy(self):
                                                         )],
                                                         policy=policy,
                                                         target_key=BertEmbeddings)
-
+
         # use flash attention
         if self.shard_config.enable_flash_attention:
             policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
-                'forward': get_bert_forward(),
+                'forward': get_bert_flash_attention_forward(),
+            })
+
+        # use jit operator
+        if self.shard_config.enable_jit_fused:
+            policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_bert_self_output_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[BertOutput] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_bert_output_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
             })

         return policy
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index daae3cf881a2..b2c86ec6c632 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -3,7 +3,14 @@
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
-from ..modeling.bloom import build_bloom_alibi_tensor_fn, get_bloom_forward
+from ..modeling.bloom import (
+    build_bloom_alibi_tensor_fn,
+    get_bloom_flash_attention_forward,
+    get_jit_fused_bloom_attention_forward,
+    get_jit_fused_bloom_gelu_forward,
+    get_jit_fused_bloom_mlp_forward,
+)
+from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


@@ -25,7 +32,7 @@ def preprocess(self):
         return self.model

     def module_policy(self):
-        from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, BloomAttention
+        from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel

         policy = {}

@@ -101,10 +108,26 @@ def module_policy(self):
             ],
                                                         policy=policy,
                                                         target_key=BloomBlock)
-
+
         if self.shard_config.enable_flash_attention:
             policy[BloomAttention] = ModulePolicyDescription(method_replacement={
-                'forward': get_bloom_forward(),
+                'forward': get_bloom_flash_attention_forward(),
+                'dropout_add': get_dropout_add_func()
+            })
+
+        # enable jit fused operator
+        if self.shard_config.enable_jit_fused:
+            policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_bloom_attention_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[BloomMLP] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_bloom_mlp_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[BloomGelu] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_bloom_gelu_forward(),
+                'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
             })

         return policy
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index cd11bdb39f94..f4cc18024f6d 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -3,8 +3,8 @@
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
+from ..modeling.gpt2 import get_gpt2_flash_attention_forward
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-from ..modeling.gpt2 import get_gpt2_forward

 __all__ = [
     'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
@@ -30,7 +30,7 @@ def preprocess(self):
         return self.model

     def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model, GPT2Attention
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

         policy = {}

@@ -107,11 +107,11 @@ def module_policy(self):
             ],
                                                         policy=policy,
                                                         target_key=GPT2Block)
-
+
         if self.shard_config.enable_flash_attention:
             policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
-                'forward': get_gpt2_forward(),
-            })
+                'forward': get_gpt2_flash_attention_forward(),
+            })

         return policy

     def postprocess(self):
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 7664492f7686..c4ed5bb0e342 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -4,8 +4,8 @@

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

+from ..modeling.llama import get_llama_flash_attention_forward
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-from ..modeling.llama import get_llama_forward

 __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']

@@ -27,7 +27,7 @@ def preprocess(self):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaAttention
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

         policy = {}

@@ -99,11 +99,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         ),
                                                         policy=policy,
                                                         target_key=LlamaModel)
-
+
         if self.shard_config.enable_flash_attention:
             policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
-                'forward': get_llama_forward(),
-            })
+                'forward': get_llama_flash_attention_forward(),
+            })

         return policy

diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index e960227415cc..f363096158ae 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -1,7 +1,8 @@
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .._utils import getattr_, setattr_
-from ..modeling.opt import get_opt_forward
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
@@ -93,7 +94,14 @@ def module_policy(self):
         # use flash attention
         if self.shard_config.enable_flash_attention:
             policy[OPTAttention] = ModulePolicyDescription(method_replacement={
-                'forward': get_opt_forward(),
+                'forward': get_opt_flash_attention_forward(),
+            })
+
+        # use jit fused operator
+        if self.shard_config.enable_jit_fused:
+            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_opt_decoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
             })

         return policy
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 5463878b04b3..846451ab57ee 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -9,8 +9,14 @@
 from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription

 from .._utils import getattr_, setattr_
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.t5 import (
+    get_jit_fused_T5_layer_ff_forward,
+    get_t5_flash_attention_forward,
+    get_T5_layer_cross_attention_forward,
+    get_T5_layer_self_attention_forward,
+)
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-from ..modeling.t5 import get_t5_forward

 __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]

@@ -166,7 +172,22 @@ def module_policy(self):
         # use flash attention
         if self.shard_config.enable_flash_attention:
             policy[T5Attention] = ModulePolicyDescription(method_replacement={
-                'forward': get_t5_forward(),
+                'forward': get_t5_flash_attention_forward(),
+            })
+
+        # use jit operator
+        if self.shard_config.enable_jit_fused:
+            policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_T5_layer_ff_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_T5_layer_self_attention_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_T5_layer_cross_attention_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
             })

         return policy
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 4076becdbbc9..f9d93609feef 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -22,6 +22,7 @@ class ShardConfig:
     enable_fused_normalization: bool = False
     enable_all_optimization: bool = False
     enable_flash_attention: bool = False
+    enable_jit_fused: bool = False

     # TODO: add support for tensor parallel
     # pipeline_parallel_size: int
@@ -51,3 +52,4 @@ def _turn_on_all_optimization(self):
         # you can add all the optimization flag here
         self.enable_fused_normalization = True
         self.enable_flash_attention = True
+        self.enable_jit_fused = True
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 44a8f841a6d4..4dcaac143525 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -6,14 +6,16 @@
 def build_model(model_fn,
                 enable_fused_normalization=True,
                 enable_tensor_parallelism=True,
-                enable_flash_attention=False):
+                enable_flash_attention=False,
+                enable_jit_fused=False):
     # create new model
     org_model = model_fn().cuda()

     # shard model
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism,
-                               enable_flash_attention=enable_flash_attention)
+                               enable_flash_attention=enable_flash_attention,
+                               enable_jit_fused=enable_jit_fused)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model = shard_former.optimize(model_copy).cuda()
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 865ba2fa7126..a025b0cee0b3 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -70,10 +70,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('enable_flash_attention', [True, False])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
+@parameterize('enable_jit_fused', [True, False])
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention, enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index 63438bbe9a91..c17d55a9126c 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -70,10 +70,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('enable_flash_attention', [True, False])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
+@parameterize('enable_jit_fused', [True, False])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention, enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 9ca0669340a5..0eb04960460b 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -73,11 +73,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('enable_flash_attention', [True, False])
-def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
+@parameterize('enable_jit_fused', [True, False])
+def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               enable_flash_attention)
+                                               enable_flash_attention, enable_jit_fused)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index ff4b71fe76a4..cf2efeb11f43 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -83,10 +83,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('enable_flash_attention', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
+@parameterize('enable_jit_fused', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               enable_flash_attention, enable_jit_fused)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()
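
For reference, the new flags introduced by this diff can be exercised end to end with a minimal sketch like the one below (not part of the patch). It mirrors the `build_model` test utility: the BERT model choice, sizes, and the single-GPU `torchrun` setup are illustrative assumptions, and tensor parallelism is left off so no process group sharding is needed.

```python
# minimal_jit_fused_demo.py -- illustrative sketch, not part of this PR.
# Assumptions: one CUDA device, xformers installed, launched via
# `torchrun --nproc_per_node=1 minimal_jit_fused_demo.py` so that
# colossalai can initialize its distributed state.
import colossalai
import torch
from transformers import BertConfig, BertForSequenceClassification

from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch(config={})

# build a small model to shard
model = BertForSequenceClassification(BertConfig(num_hidden_layers=2)).cuda()

# turn on the two optimizations added in this PR; tensor parallelism stays
# off so the sketch runs on a single GPU
shard_config = ShardConfig(enable_tensor_parallelism=False,
                           enable_flash_attention=True,
                           enable_jit_fused=True)

sharded_model = ShardFormer(shard_config=shard_config).optimize(model).cuda()

# the replaced forwards (xformers memory-efficient attention plus the jit
# fused dropout_add paths) are used transparently on a normal forward pass
input_ids = torch.randint(0, 30522, (2, 128), device='cuda')
outputs = sharded_model(input_ids=input_ids)
print(outputs.logits.shape)
```

Note on the replacement semantics: `dropout_add(x, residual, prob, training)` from `modeling_bloom` computes `residual + dropout(x)`, and the jit variants in `modeling/jit.py` fuse the same computation with an all-zero bias through `bias_dropout_add_fused_train` / `bias_dropout_add_fused_inference`, which is why each `method_replacement` that swaps in a jit-fused `forward` also installs a matching `dropout_add` method on the module.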