Merged
3 changes: 3 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1180,7 +1180,10 @@ def __init__(
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
pg_mesh=self.pg_mesh,
sp_axis=self.sp_axis,
)

self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
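The hunk above threads the plugin's process-group mesh and sequence-parallel axis into the shardformer config instead of handing over a pre-built sp_group. A minimal sketch of the two objects being wired up, assuming the (dp, pp, sp, tp) axis ordering used in the tests at the bottom of this diff; the sizes and the kwargs dict are illustrative, not the plugin's actual internals:

import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

# Build one global mesh and remember which axis holds the sequence-parallel ranks.
dp_size, pp_size, tp_size = 1, 1, 1
sp_size = dist.get_world_size()
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
sp_axis = 2  # index of the sp dimension in the mesh above

# These two values are what now flow into ShardConfig (and from there into RingAttention),
# replacing the previously pre-created sequence-parallel process group handle.
shardformer_kwargs = dict(pg_mesh=pg_mesh, sp_axis=sp_axis)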
40 changes: 19 additions & 21 deletions colossalai/shardformer/layer/attn.py
@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
INTER_RING_GROUP_COPY: dist.ProcessGroup = None

@staticmethod
def get_double_ring_groups(sp_group, inner_ring_size=None):
def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
"""
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
shouldn't be larger than the number of NICs on each node.
@@ -441,21 +441,17 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
Returns:
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
"""
assert pg_mesh is not None, "Error: The pg_mesh is None! Please check the process group initialization."

sp_group = pg_mesh.get_group_along_axis(sp_axis)
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)

if inner_ring_size is None:
if torch.cuda.device_count() >= dist.get_world_size():
# single node, no need to consider NICs
return sp_group, sp_group
if sp_size <= 4:
inner_ring_size = min(2, sp_size)
else:
inner_ring_size = min(4, sp_size)
else:
assert (
inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
assert inner_ring_size is not None

assert (
inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

if inner_ring_size == sp_size:
return sp_group, sp_group
@@ -474,14 +470,14 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
# Create inner ring groups
for i in range(inner_ring_size):
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
group = dist.new_group(ranks)
group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inner_ring_group = group

# Create inter ring groups
for i in range(num_rings):
ranks = list(range(i, sp_size, num_rings))
group = dist.new_group(ranks)
group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inter_ring_group = group
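For concreteness, here is a standalone sketch of the rank partitioning the two loops above produce, assuming sp_size = 4 and inner_ring_size = 2 (so num_rings = 2); it only computes the rank lists and creates no process groups:

# Illustration only: mirrors the range() arithmetic above for a 4-rank sequence-parallel group.
sp_size, inner_ring_size = 4, 2
num_rings = sp_size // inner_ring_size

inner_rings = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(inner_ring_size)]
inter_rings = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]

print(inner_rings)  # [[0, 1], [2, 3]] -> peers of the inner (intra-node) ring
print(inter_rings)  # [[0, 2], [1, 3]] -> peers of the inter (cross-node) ring

Each rank ends up in exactly one inner ring and one inter ring, which is the pair of groups get_double_ring_groups returns for it.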

@@ -492,7 +488,7 @@ def attention(
q, # (B, H, Sq, D)
k,
v,
sp_group,
sp_axis,
attention_mask_type,
cu_seqlens=None,
max_seqlen=None,
@@ -502,6 +498,7 @@
deterministic=False,
return_softmax=False,
inner_ring_size=None,
pg_mesh=None,
**kwargs,
):
"""
@@ -512,7 +509,7 @@
q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
sp_axis (Optional[int]): The axis of the sequence-parallel group in the global process group mesh.
sp_stream (torch.cuda.Stream): A different stream for output correction.
cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into q.
@@ -537,7 +534,6 @@
RingAttention.ATTN_DONE = torch.cuda.Event()
if RingAttention.SP_STREAM is None:
RingAttention.SP_STREAM = torch.cuda.Stream()

assert (
q.shape[2] == k.shape[2]
), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -546,11 +542,13 @@
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
), f"Mask type {attention_mask_type} is not supported yet."

clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
assert pg_mesh is not None, "Error: The pg_mesh is None! Please check the process group initialization."

if RingAttention.SP_GROUP is not sp_group:
clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
sp_group = pg_mesh.get_group_along_axis(sp_axis)
if inner_ring_size is not None:
RingAttention.SP_GROUP = sp_group
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
RingAttention.INNER_RING_GROUP = inner_ring_group
RingAttention.INTER_RING_GROUP = inter_ring_group
else:
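With the new signature, a caller passes the axis index and the global mesh and lets RingAttention derive (and cache) the sequence-parallel, inner-ring, and inter-ring groups on first use. A minimal call-site sketch mirroring the test at the bottom of this diff; the tensor shapes, dtype, and mask type are illustrative, and in practice the local q/k/v come from a zigzag split of the full sequence (see split_batch_zigzag in the tests):

import torch
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import RingAttention

# Assumes colossalai.launch() has already initialized the default process group
# and set the current CUDA device for this rank.
sp_size = dist.get_world_size()
pg_mesh = ProcessGroupMesh(1, 1, sp_size, 1)  # (dp, pp, sp, tp)
sp_axis = 2

b, h, sq, d = 2, 8, 4096 // sp_size, 128  # sq is the local (already sharded) sequence length
q, k, v = (
    torch.randn(b, h, sq, d, device="cuda", dtype=torch.bfloat16, requires_grad=True) for _ in range(3)
)

out = RingAttention.attention(
    q,
    k,
    v,
    sp_axis,
    AttnMaskType.CAUSAL,
    inner_ring_size=None,  # None follows the single-ring test path; an int (e.g. 2) enables double-ring
    pg_mesh=pg_mesh,
)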
4 changes: 2 additions & 2 deletions colossalai/shardformer/modeling/gpt2.py
@@ -857,17 +857,17 @@ def forward(
dropout_p = self.attn_dropout.p if self.training else 0.0

sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query,
key,
value,
sp_group,
sp_axis=shard_config.sp_axis,
**attention_mask,
dropout_p=dropout_p,
scale=scale,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
)
else:
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
3 changes: 2 additions & 1 deletion colossalai/shardformer/modeling/llama.py
@@ -568,9 +568,10 @@ def forward(
query_states,
key_states,
value_states,
sp_group,
sp_axis=shard_config.sp_axis,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
pg_mesh=shard_config.pg_mesh,
)

elif shard_config.enable_flash_attention:
2 changes: 2 additions & 0 deletions colossalai/shardformer/shard/shard_config.py
@@ -51,6 +51,8 @@ class ShardConfig:
extra_kwargs: Dict[str, Any] = field(default_factory=dict)

# For ring attention
sp_axis: Optional[int] = None
pg_mesh: Optional[int] = None
inner_ring_size: Optional[int] = None
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
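These two fields are what the modeling code above reads back as shard_config.sp_axis and shard_config.pg_mesh. A hedged construction sketch: the import path and the companion flags (enable_sequence_parallelism, enable_flash_attention) are assumptions about fields not shown in this diff, and note that pg_mesh is annotated Optional[int] here even though it receives a ProcessGroupMesh at runtime:

from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer import ShardConfig  # import path assumed

pg_mesh = ProcessGroupMesh(1, 1, 4, 1)  # (dp, pp, sp, tp); sizes are illustrative
sp_axis = 2

shard_config = ShardConfig(
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",
    sequence_parallel_process_group=pg_mesh.get_group_along_axis(sp_axis),
    enable_flash_attention=True,
    sp_axis=sp_axis,
    pg_mesh=pg_mesh,  # holds the mesh object despite the Optional[int] annotation
    inner_ring_size=2,
)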
20 changes: 13 additions & 7 deletions tests/test_shardformer/test_layer/test_ring_attn.py
@@ -5,6 +5,7 @@
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -17,11 +18,14 @@
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype):
def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
torch.cuda.manual_seed(2)
device = get_current_device()
sp_group = dist.group.WORLD
dp_size, pp_size, tp_size = 1, 1, 1
sp_size = dist.get_world_size()
sp_axis = 2
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
# Some outliers may seem large, but our errors are still lower than
# Megatron-LM context parallel's
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
q,
k,
v,
sp_group,
sp_axis,
AttnMaskType.CAUSAL,
return_softmax=True,
inner_ring_size=max(2, sp_size // 2),
# inner_ring_size=4
inner_ring_size=inner_ring_size,
pg_mesh=pg_mesh,
)
ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func(
@@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
sp_axis = 2
atol = rtol = 7e-3
torch.cuda.manual_seed(2)
# Prepare varlen attention mask
@@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
q_ring,
k_ring,
v_ring,
sp_group,
sp_axis,
**mask_info,
pad_output=False,
return_softmax=True,
pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
# deterministic=True
)
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
@@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq()
check_ring_attn()
check_ring_attn(inner_ring_size=None)


def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_ring_attn()
check_ring_attn(inner_ring_size=2)


@rerun_if_address_is_in_use()
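These launchers are typically driven by ColossalAI's multiprocess test helper. A hedged sketch of how the two configurations might be spawned; the spawn helper, its signature, and the process counts are assumptions rather than part of this diff:

from colossalai.testing import rerun_if_address_is_in_use, spawn  # spawn assumed available here


@rerun_if_address_is_in_use()
def test_ring_attn():
    # Single-ring path: inner_ring_size=None inside launch_single_ring.
    spawn(launch_single_ring, nprocs=4)


@rerun_if_address_is_in_use()
def test_double_ring():
    # Double-ring path: inner_ring_size=2, so at least 4 ranks are needed.
    spawn(launch_double_ring, nprocs=4)


if __name__ == "__main__":
    test_ring_attn()
    test_double_ring()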