105 changes: 105 additions & 0 deletions applications/Chat/coati/models/bloom/triton_attention_forward.py
@@ -0,0 +1,105 @@
from typing import Optional, Tuple

import torch
from torch.nn import functional as F
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention

from colossalai.kernel.triton.ops import compute_attention_for_bloom

def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
"""
Dropout add function

Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *required*):
            residual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = F.dropout(x, p=prob, training=training)
out = residual + out
return out

class TritonBloomAttention(BloomAttention):
def __init__(self, config: BloomConfig):
super(TritonBloomAttention, self).__init__(config)

def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]

# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
q_length = query_layer.shape[1]
batch_size = query_layer.shape[0]
num_heads = query_layer.shape[2]

query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)

if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=2)
value_layer = torch.cat((past_value, value_layer), dim=1)

_, _, kv_length = key_layer.shape
alibi = alibi.view(batch_size, num_heads, q_length, -1)

        context_layer = compute_attention_for_bloom(
            q=query_layer.view(batch_size, self.num_heads, q_length, self.head_dim),
            k=key_layer.view(batch_size, self.num_heads, self.head_dim, kv_length),
            v=value_layer.view(batch_size, self.num_heads, kv_length, self.head_dim),
            alibi=alibi,
            beta=self.beta,
            scale=self.inv_norm_factor,
            attention_mask=attention_mask,
            drop_out=self.hidden_dropout,
            head_mask=head_mask,
            layer_past=layer_past,
            use_cache=True,
        )

if use_cache:
present = (key_layer, value_layer)
else:
present = None

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)

output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

outputs = (output_tensor, present)
        if output_attentions:
            # The fused Triton kernel does not expose attention probabilities, so they cannot be returned.
            raise NotImplementedError("output_attentions is not supported by the fused Triton attention forward")

return outputs
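
For reviewers of the kernel call above, a minimal eager-PyTorch sketch of the computation that compute_attention_for_bloom is assumed to perform: alibi-biased scaled attention with boolean masking, softmax, dropout, and an optional head mask. The function name reference_bloom_attention and the exact masking convention are illustrative assumptions, not part of this PR.

import torch
from torch.nn import functional as F

def reference_bloom_attention(q, k, v, alibi, beta, scale, attention_mask,
                              dropout_p=0.0, head_mask=None, training=False):
    # q: [batch, num_heads, q_length, head_dim]
    # k: [batch, num_heads, head_dim, kv_length]
    # v: [batch, num_heads, kv_length, head_dim]
    # alibi: [batch, num_heads, q_length, kv_length] positional bias
    # scores = beta * alibi + scale * (q @ k), mirroring the baddbmm in BloomAttention
    scores = beta * alibi + scale * torch.matmul(q, k)
    # attention_mask is assumed boolean, with True marking positions to hide
    scores = scores.masked_fill(attention_mask, torch.finfo(scores.dtype).min)
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
    probs = F.dropout(probs, p=dropout_p, training=training)
    if head_mask is not None:
        probs = probs * head_mask
    # context: [batch, num_heads, q_length, head_dim]
    return torch.matmul(probs, v)
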
11 changes: 10 additions & 1 deletion applications/Chat/coati/trainer/strategies/tp/policy.py
@@ -3,11 +3,14 @@
from typing import Dict, Type

import torch.nn as nn
from coati.models.lora import LoraLinear
from torch.nn import functional as F
from torch.nn import Module
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomForCausalLM, BloomMLP
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

from coati.models.lora import LoraLinear
from coati.models.bloom.triton_attention_forward import TritonBloomAttention
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.lazy import LazyTensor
@@ -132,6 +135,12 @@ def replace(self, module: Module) -> bool:
module.forward = MethodType(bloom_attn_fwd, module)
return False

class BloomAttentionTritonPolicy(Policy):
    """Replace BloomAttention.forward with the Triton-fused TritonBloomAttention.forward."""

    def replace(self, module: Module) -> bool:
        assert isinstance(module, BloomAttention)
        module.forward = MethodType(TritonBloomAttention.forward, module)
        return False

class BloomMLPPolicy(Policy):

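
A minimal usage sketch of the new policy, assuming attention forwards are swapped in place on an already-built model. apply_attention_policy and the checkpoint name are illustrative placeholders, not how the surrounding TP strategy actually drives its policies.

from transformers import BloomForCausalLM
from transformers.models.bloom.modeling_bloom import BloomAttention

def apply_attention_policy(model, policy):
    # Walk the module tree and let the policy patch every BloomAttention it finds.
    for module in model.modules():
        if isinstance(module, BloomAttention):
            policy.replace(module)

model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")
apply_attention_policy(model, BloomAttentionTritonPolicy())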