colossalai/booster/plugin/hybrid_parallel_plugin.py (5 changes: 4 additions & 1 deletion)

```diff
@@ -147,6 +147,7 @@ def __init__(
         zero_stage: int = 0,
         cpu_offload: bool = False,
         enable_fused_normalization: bool = False,
+        enable_sequence_parallelism: bool = False,
         num_microbatches: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
@@ -170,6 +171,7 @@ def __init__(
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
         self.enable_fused_normalization = enable_fused_normalization
+        self.enable_sequence_parallelism = enable_sequence_parallelism
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
@@ -184,7 +186,8 @@ def __init__(
         self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                         pipeline_stage_manager=self.stage_manager,
                                         enable_tensor_parallelism=self.tp_size > 1,
-                                        enable_fused_normalization=self.enable_fused_normalization)
+                                        enable_fused_normalization=self.enable_fused_normalization,
+                                        enable_sequence_parallelism=enable_sequence_parallelism)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
```
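For context, a minimal usage sketch (not part of the diff) of how the new `enable_sequence_parallelism` flag reaches `ShardConfig`. The `tp_size`/`pp_size` values are illustrative, and the `Booster` wiring is assumed from ColossalAI's public API rather than shown here.

```python
# Hypothetical usage sketch: argument values are illustrative, and the
# Booster wiring is assumed from ColossalAI's public API, not this diff.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_fused_normalization=True,
    enable_sequence_parallelism=True,  # new flag; forwarded into ShardConfig
)
booster = Booster(plugin=plugin)
```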
colossalai/kernel/cuda_native/flash_attention.py (26 changes: 19 additions & 7 deletions)

```diff
@@ -6,6 +6,7 @@
 import math
 import os
 import subprocess
+import warnings

 import torch

@@ -14,15 +15,20 @@
     HAS_MEM_EFF_ATTN = True
 except ImportError:
     HAS_MEM_EFF_ATTN = False
-    print('please install xformers from https://github.com/facebookresearch/xformers')
+    warnings.warn(f'please install xformers from https://github.com/facebookresearch/xformers')

 if HAS_MEM_EFF_ATTN:

     from typing import Optional

     from einops import rearrange
     from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
-    from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
+    from xformers.ops.fmha.attn_bias import (
+        BlockDiagonalCausalMask,
+        BlockDiagonalMask,
+        LowerTriangularMask,
+        LowerTriangularMaskWithTensorBias,
+    )

     from .scaled_softmax import AttnMaskType

@@ -86,11 +92,14 @@ def backward(ctx, grad_output):

     class ColoAttention(torch.nn.Module):

-        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
+        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
             super().__init__()
             assert embed_dim % num_heads == 0, \
                 f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
-            self.scale = 1 / math.sqrt(embed_dim // num_heads)
+            if scale is not None:
+                self.scale = scale
+            else:
+                self.scale = 1 / math.sqrt(embed_dim // num_heads)
             self.dropout = dropout

         @staticmethod
@@ -116,7 +125,7 @@ def forward(self,
                     bias: Optional[torch.Tensor] = None):
             batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
             attn_bias = None
-            if attn_mask_type == AttnMaskType.padding:    # bert style
+            if attn_mask_type and attn_mask_type.value % 2 == 1:    # bert style
                 assert attn_mask is not None, \
                     f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
                 assert attn_mask.dim() == 2, \
@@ -134,7 +143,10 @@ def forward(self,
                 if batch_size > 1:
                     query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                     key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
-                attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+                if attn_mask_type == AttnMaskType.padding:
+                    attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+                elif attn_mask_type == AttnMaskType.paddedcausal:
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(q_seqlen, kv_seqlen)
             elif attn_mask_type == AttnMaskType.causal:    # gpt style
                 attn_bias = LowerTriangularMask()

@@ -146,7 +158,7 @@

             out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)

-            if attn_mask_type == AttnMaskType.padding and batch_size > 1:
+            if attn_mask_type and attn_mask_type.value % 2 == 1 and batch_size > 1:
                 out = self.repad(out, q_indices, batch_size, tgt_len)

             out = rearrange(out, 'b s h d -> b s (h d)')
```
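Taken together, the `ColoAttention` changes let a caller supply a custom softmax scale (overriding the default `1 / sqrt(head_dim)`) and combine a per-token padding mask with causal masking via the new `paddedcausal` type. A hedged sketch of how this might be called; it assumes xformers and a CUDA device are available, infers the `[batch, seq, heads, head_dim]` input layout from the `rearrange` patterns in the diff, and uses keyword names (`attn_mask`, `attn_mask_type`) as they appear in the visible code.

```python
# Hedged usage sketch (not part of the diff); see assumptions in the text above.
import torch
from colossalai.kernel.cuda_native.flash_attention import ColoAttention
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

batch, seq, num_heads, head_dim = 2, 16, 4, 16
embed_dim = num_heads * head_dim

# scale=None keeps the default 1 / sqrt(head_dim); any float overrides it
attn = ColoAttention(embed_dim, num_heads, dropout=0.0, scale=0.125)

q = torch.rand(batch, seq, num_heads, head_dim, device='cuda', dtype=torch.float16)
k, v = torch.rand_like(q), torch.rand_like(q)

# causal (GPT-style): no explicit mask tensor is needed
out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

# padded + causal: a per-token padding mask combined with causal masking,
# lowered to BlockDiagonalCausalMask after unpadding
pad_mask = torch.ones(batch, seq, dtype=torch.int32, device='cuda')
pad_mask[:, -4:] = 0    # mark the last four positions as padding
out = attn(q, k, v, attn_mask=pad_mask, attn_mask_type=AttnMaskType.paddedcausal)
```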
colossalai/kernel/cuda_native/scaled_softmax.py (5 changes: 3 additions & 2 deletions)

```diff
@@ -19,6 +19,7 @@
 class AttnMaskType(enum.Enum):
     padding = 1
     causal = 2
+    paddedcausal = 3


 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
         if 0 <= sk <= 2048:
             batch_per_block = self.get_batch_per_block(sq, sk, b, np)

-            if self.attn_mask_type == AttnMaskType.causal:
+            if self.attn_mask_type.value > 1:
                 if attn_batches % batch_per_block == 0:
                     return True
             else:
@@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask):
         b, np, sq, sk = input.size()
         scale = self.scale if self.scale is not None else 1.0

-        if self.attn_mask_type == AttnMaskType.causal:
+        if self.attn_mask_type.value > 1:
             assert sq == sk, "causal mask is only for self attention"

         # input is 3D tensor (attn_batches, sq, sk)
```
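The rewritten checks above work because of how the enum values are chosen: odd values (`padding = 1`, `paddedcausal = 3`) mark mask types that need padding handling, while values greater than 1 (`causal = 2`, `paddedcausal = 3`) mark mask types that need causal handling, so `paddedcausal` triggers both paths. A self-contained sketch of the encoding; the helper names are mine, for illustration only.

```python
# Self-contained sketch of the AttnMaskType value encoding used by the diff.
import enum

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3

def is_padding_like(mask_type) -> bool:
    # mirrors `attn_mask_type and attn_mask_type.value % 2 == 1` in ColoAttention.forward
    return mask_type is not None and mask_type.value % 2 == 1

def is_causal_like(mask_type) -> bool:
    # mirrors `self.attn_mask_type.value > 1` in scaled_softmax.py
    return mask_type.value > 1

assert is_padding_like(AttnMaskType.padding) and is_padding_like(AttnMaskType.paddedcausal)
assert is_causal_like(AttnMaskType.causal) and is_causal_like(AttnMaskType.paddedcausal)
assert not is_padding_like(AttnMaskType.causal) and not is_causal_like(AttnMaskType.padding)
```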