Merged
51 commits
0593f04
sequence parallel optimization
KKZ20 Dec 9, 2023
fe5fac2
validate sequence parallel in llama (code to be polished)
KKZ20 Jan 3, 2024
ee95f94
shardformer api writing
KKZ20 Jan 5, 2024
98a2eeb
integrate sequence parallel in ShardFormer
KKZ20 Jan 8, 2024
28c11b7
fix pp bugs and sp bugs for LlaMa model
KKZ20 Jan 9, 2024
cd41e42
integrating ring-based sequence parallelism into ShardFormer
KKZ20 Jan 10, 2024
391dc64
fix bugs when using sp and flashattention together
KKZ20 Jan 10, 2024
13fc14c
fix operation function name
KKZ20 Jan 12, 2024
83e6044
support flash attention for ulysses-style sp
KKZ20 Jan 17, 2024
7557691
clarify sp process group
KKZ20 Jan 17, 2024
9698a87
fix compatibility bugs in moe plugin
KKZ20 Jan 17, 2024
7a31083
fix fused linear bugs
KKZ20 Jan 17, 2024
74457df
fix linear layer test
KKZ20 Jan 17, 2024
858f55d
support gpt model all-to-all sp
KKZ20 Jan 23, 2024
0b115b4
modify shard data dimension (meant to be dim=-1)
KKZ20 Jan 23, 2024
d146040
support megatron-style sp and distributed attn for llama model
linsj20 Jan 23, 2024
362b5b6
finish sp mode 3 support for gpt
KKZ20 Jan 24, 2024
7293b16
using all_to_all_single when batch size is 1
KKZ20 Jan 24, 2024
65db8b2
support mode 2 sp in gpt2 (#5)
linsj20 Jan 26, 2024
e72bd87
polish code
KKZ20 Jan 26, 2024
2076bcf
enable distributed attn mask when using sp mode 2 and 3 in llama
KKZ20 Feb 1, 2024
bb18577
automatically enable flash attn when using sp mode 2 and 3 in llama
KKZ20 Feb 1, 2024
9788fd8
inplace attn mask
KKZ20 Feb 1, 2024
544a06d
add zero2 support for sequence parallel
KKZ20 Feb 19, 2024
c3d0e83
polish code
KKZ20 Feb 27, 2024
9f2f1fe
fix bugs
KKZ20 Feb 27, 2024
33963a3
fix gemini checkpoint io
KKZ20 Feb 27, 2024
700c26d
loosen tensor checking atol and rtol
KKZ20 Feb 28, 2024
9a36add
add comment
KKZ20 Mar 11, 2024
0e0ac18
fix llama layernorm grad
KKZ20 Mar 13, 2024
cbb3025
fix zero grad
KKZ20 Mar 13, 2024
3391d3e
fix zero grad
KKZ20 Mar 13, 2024
cc28bd4
fix conflict
KKZ20 Mar 13, 2024
1a3825d
update split and gather auto grad func
KKZ20 Mar 18, 2024
76a22da
sequence parallel: inside text split (#6)
linsj20 Mar 20, 2024
7e80cc4
polish code (part 1)
KKZ20 Mar 25, 2024
eff6978
polish code (part 2)
KKZ20 Mar 25, 2024
26f7bf8
polish code (part 2.5)
KKZ20 Mar 25, 2024
2beac05
polish code (part 3)
linsj20 Mar 26, 2024
e5dcd93
polish code
KKZ20 Mar 27, 2024
ace07c9
fix ulysses style ZeRO
linsj20 Mar 27, 2024
56a5ba8
fix llama and gpt sp
KKZ20 Mar 28, 2024
2a30925
Merge branch 'main' into rebase/sp
KKZ20 Mar 28, 2024
93c958f
polish code
KKZ20 Apr 1, 2024
48580c7
move ulysses grad sync to ddp (#9)
linsj20 Apr 2, 2024
aea4fb6
remove zero_stage and unbind the grad sync for alltoall sp
KKZ20 Apr 2, 2024
07ae37b
add 2d group creation test
linsj20 Apr 3, 2024
145e879
remove useless code
KKZ20 Apr 3, 2024
7c31455
change shard config not to enable sp when enable_all_optimizations
KKZ20 Apr 3, 2024
794800a
add sp warnings for several models
KKZ20 Apr 3, 2024
daec9e8
remove useless code
KKZ20 Apr 3, 2024
89 changes: 75 additions & 14 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -34,7 +34,8 @@

from .pp_plugin_base import PipelinePluginBase

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]

PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

@@ -53,6 +54,7 @@ def __init__(
shard_config: ShardConfig,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
Expand All @@ -61,6 +63,7 @@ def __init__(
self.shard_config = shard_config
self.dp_group = dp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.use_dpp = use_ddp
self.require_grad_sync = True

@@ -168,13 +171,24 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
Returns:
None
"""
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:

if self.shard_config.enable_sequence_parallelism:
if self.shard_config.sequence_parallelism_mode == "all_to_all":
return

if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
# If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized
# across the tensor parallelism group.
group = self.tp_group
else:
raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")

if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads)
else:
# Synchronize gradients from the model across the tensor parallelism group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module)

def forward(self, *args, **kwargs):
if self.convert_fn is not None:
@@ -727,10 +741,9 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)

if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
else:
return

@@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
sp_size (int): The size of sequence parallelism.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
@@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
@@ -937,13 +952,15 @@ def __init__(
self,
tp_size: int,
pp_size: int,
sp_size: int = None,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
@@ -974,14 +991,41 @@ def __init__(
super().__init__()
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
assert (
self.sequence_parallelism_mode in SUPPORT_SP_MODE
), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
if self.sequence_parallelism_mode in ["split_gather", "ring"]:
assert (
tp_size > 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
if sp_size != 1:
warnings.warn(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
elif self.sequence_parallelism_mode in ["all_to_all"]:
assert (
tp_size == 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
assert (
pp_size == 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
self.sp_size = dist.get_world_size() if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert (
sp_size == 1 or sp_size is None
), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
self.sp_size = 1

self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@@ -990,7 +1034,7 @@ def __init__(
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
@@ -1031,16 +1075,22 @@ def __init__(
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
)
@@ -1110,13 +1160,23 @@ def configure(
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
else:
dp_group = self.dp_group
model = HybridParallelModule(
model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
@@ -1146,7 +1206,8 @@ def configure(
tp_process_group=self.tp_group,
)
else:
if self.dp_size == 1:
zero_dp_size = dist.get_world_size(dp_group)
if zero_dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
@@ -1158,7 +1219,7 @@ def configure(
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
dp_process_group=dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
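Taken together, the plugin changes above expose three selectable modes: "split_gather" and "ring" reuse the tensor-parallel group as the sequence-parallel group (so they require tp_size > 1 and force sp_size to 1), while "all_to_all" requires tp_size == 1 and pp_size == 1 and takes its own sp_size, defaulting to the world size. The sketch below shows how the new arguments might be wired up; the 4-GPU world size, the launch call, and the training objects referenced in comments are illustrative assumptions, not part of this diff.

```python
# Minimal sketch of configuring the new sequence-parallelism modes
# (assumes a 4-GPU run launched with `colossalai run --nproc_per_node 4 ...`).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # distributed init; exact signature depends on the ColossalAI version

# Megatron-style SP ("split_gather" or "ring"): the SP group is the TP group,
# so tp_size must be > 1 and any given sp_size is ignored (forced to 1).
megatron_sp_plugin = HybridParallelPlugin(
    tp_size=4,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",  # or "ring"
)

# Ulysses-style SP ("all_to_all"): tp_size and pp_size must both be 1;
# sp_size defaults to the world size when left unset.
ulysses_sp_plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)

booster = Booster(plugin=ulysses_sp_plugin)
# model, optimizer, criterion, dataloader are assumed to be defined elsewhere:
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```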
4 changes: 4 additions & 0 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -254,6 +254,9 @@ def __init__(
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
# TODO: Currently MoE only partially supports sequence parallelism
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
@@ -365,6 +368,7 @@ def configure(
shard_config=self.shard_config,
dp_group=self.dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
37 changes: 29 additions & 8 deletions colossalai/cluster/process_group_mesh.py
@@ -161,7 +161,7 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:

@staticmethod
def get_coords_along_axis(
base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
) -> List[Tuple[int, ...]]:
"""Get coordinates along the given axis.

@@ -173,13 +173,28 @@ def get_coords_along_axis(
Returns:
List[Tuple[int, ...]]: Coordinates along the axis.
"""
coords_in_group = []
for idx in indices_at_axis:
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
if isinstance(axis, int):
axis = [axis,]
assert isinstance(indices_at_axis[0], int)
indices_at_axis = [indices_at_axis,]

def add_index(base_coord, axis, indices_at_axis):
coords_in_group = []
for idx in indices_at_axis:
coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
return coords_in_group

coords_in_group = [base_coord]
for ax, indices_at_ax in zip(axis, indices_at_axis):
new_coords_in_group = []
for coords in coords_in_group:
new_coords_in_group += add_index(coords, ax, indices_at_ax)
coords_in_group = new_coords_in_group

return coords_in_group

def create_group_along_axis(
self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None
) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.

Expand All @@ -191,10 +206,17 @@ def create_group_along_axis(
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
"""
indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
if isinstance(axis, int):
axis = [axis,]
if indices_at_axis is not None:
assert isinstance(indices_at_axis[0], int)
indices_at_axis = [indices_at_axis,]

indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
reduced_shape = list(self._shape)
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
reduced_shape[axis] = 1
for ax in axis:
reduced_shape[ax] = 1
target_group = None
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
@@ -225,4 +247,3 @@ def get_group_along_axis(
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
return self._ranks_to_group[ranks_in_group]
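The multi-axis generalization above is what lets the plugin build one flat data-parallel group spanning the DP and SP axes (`create_group_along_axis([DP_AXIS, SP_AXIS])` in the plugin diff). A minimal standalone sketch of the new coordinate-expansion loop, using a hypothetical 2 x 1 x 1 x 2 (DP, PP, TP, SP) mesh, illustrates what it enumerates; it mirrors the logic for illustration only and is not the library code.

```python
from typing import List, Tuple, Union

def coords_along_axes(
    base_coord: Tuple[int, ...],
    axis: Union[int, List[int]],
    indices_at_axis: Union[List[int], List[List[int]]],
) -> List[Tuple[int, ...]]:
    """Expand base_coord along one or several axes, mirroring the generalized
    get_coords_along_axis: each listed axis is replaced in turn by every
    requested index while the other coordinates are kept."""
    if isinstance(axis, int):
        axis, indices_at_axis = [axis], [indices_at_axis]

    def add_index(coord, ax, indices):
        return [coord[:ax] + (idx,) + coord[ax + 1:] for idx in indices]

    coords = [base_coord]
    for ax, indices in zip(axis, indices_at_axis):
        coords = [new for c in coords for new in add_index(c, ax, indices)]
    return coords

# Hypothetical mesh of shape (dp=2, pp=1, tp=1, sp=2) laid out over 4 ranks.
DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
print(coords_along_axes((0, 0, 0, 0), [DP_AXIS, SP_AXIS], [[0, 1], [0, 1]]))
# -> [(0, 0, 0, 0), (0, 0, 0, 1), (1, 0, 0, 0), (1, 0, 0, 1)]
# All four ranks land in a single combined DP+SP group, which is how the
# all_to_all configuration builds the dp_group used for ZeRO/DDP gradient sync.
```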

2 changes: 2 additions & 0 deletions colossalai/shardformer/layer/__init__.py
@@ -1,4 +1,5 @@
from .attn import AttnMaskType, ColoAttention
from ._operation import all_to_all_comm
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
@@ -26,4 +27,5 @@
"ParallelModule",
"AttnMaskType",
"ColoAttention",
"all_to_all_comm",
]