24 changes: 16 additions & 8 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -380,37 +380,45 @@ def __init__(
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
self.pp_style = pp_style
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=num_model_chunks
self.pg_mesh,
pipeline_axis=PP_AXIS,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks
)

if self.pp_style == "interleaved":
if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
else:
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size
)
else:
raise NotImplementedError()

self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
scheduler=self.schedule,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
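
For context, a minimal usage sketch of the configuration exercised by the hunk above (illustrative, not part of this diff): pp_style, num_model_chunks, num_microbatches, and zero_stage appear in the hunk; tp_size, pp_size, and the Booster/launch calls are assumed from the existing plugin API.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes a distributed launch, e.g. torchrun --nproc_per_node=4 this_script.py
colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    pp_style="interleaved",   # selects the interleaved schedule wired up above
    num_model_chunks=2,       # must be > 1 for the interleaved schedule
    num_microbatches=4,       # or microbatch_size=...
    zero_stage=1,             # must be <= 1 when pipeline parallelism is enabled
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
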
170 changes: 72 additions & 98 deletions colossalai/pipeline/schedule/interleaved_pp.py

Large diffs are not rendered by default.

89 changes: 46 additions & 43 deletions colossalai/pipeline/stage_manager.py
@@ -19,13 +19,20 @@ class PipelineStageManager:
"""

def __init__(
self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False, num_model_chunks=1
self,
pg_mesh: ProcessGroupMesh,
pipeline_axis: int,
enable_interleave: bool = False,
num_model_chunks: int = 1,
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}

# init prev and next coord
coord = self.pg_mesh.coordinate()
# the prev rank of rank0 is the last rank
@@ -34,10 +41,6 @@ def __init__(
# the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks = num_model_chunks
self.model_chunk_id = 0
self.layers = Optional[List[List[int]]]

# init p2p process groups
stages = list(range(self.num_stages))
@@ -47,46 +50,60 @@ def __init__(
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group

if is_virtual:
self.is_interleave = enable_interleave
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank
# only used in interleaved pipeline for now
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group

def is_first_stage(self) -> bool:
"""Is the current stage the first stage.
# for interleaved pipeline parallel, each device is responsible for multiple chunks of layers
self.num_model_chunks: int = num_model_chunks

Returns:
bool: Whether the current stage is the first stage.
"""
return self.stage == 0 and self.model_chunk_id == 0
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None

def is_last_stage(self) -> bool:
"""Is the current stage the last stage.

Returns:
bool: Whether the current stage is the last stage.
"""
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""Is the current stage the first stage.

# introduced due to interleaved pipeline parallel, as the first/last device may also hold intermediate stages
def is_first_device(self) -> bool:
"""Is the current stage on the first device.
NOTE:
1. If using interleaved pipeline parallelism, the first stage is the first chunk of the first device.
2. Invoking is_first_stage() with model_chunk_id < 0 is equivalent to invoking is_first_device().

Returns:
bool: Whether the current stage is on the first device.
bool: Whether the current stage is the first stage.
"""
return self.stage == 0
if self.is_interleave and model_chunk_id is None:
model_chunk_id = self.model_chunk_id
assert self.is_interleave ^ (model_chunk_id is None), \
"model_chunk_id must be specified when using interleaved pipeline"
if not self.is_interleave or model_chunk_id < 0:
return self.stage == 0
else:
return self.stage == 0 and model_chunk_id == 0

def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool:
"""Is the current stage the last stage.

def is_last_device(self) -> bool:
"""Is the current stage on the last device.
NOTE:
1. If using interleaved pipeline parallelism, the last stage is the last chunk of the last device.
2. Invoking is_last_stage() with model_chunk_id < 0 is equivalent to invoking is_last_device().

Returns:
bool: Whether the current stage on the last device.
bool: Whether the current stage is the last stage.
"""
return self.stage == self.num_stages - 1
if self.is_interleave and model_chunk_id is None:
model_chunk_id = self.model_chunk_id
assert self.is_interleave ^ (model_chunk_id is None), \
"model_chunk_id must be specified when using interleaved pipeline"
if not self.is_interleave or model_chunk_id < 0:
return self.stage == self.num_stages - 1
else:
return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1

@property
def num_stages(self) -> int:
@@ -154,17 +171,3 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
ProcessGroup: Process group of the given stages.
"""
return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)

def set_interleaved_model_chunk_id(self, model_chunk_id: int):
"""For interleaved pipeline parallel, set the model chunk id for the device at the current stage.
Args:
model_chunk_id (int): the id of the current model chunk for the device.
"""
self.model_chunk_id = model_chunk_id

def set_interleaved_device_layers(self, layers: List[List[int]]):
"""For interleaved pipeline parallel, set the layer chunks for the device.
Args:
layers (List[List[int]]): list of layer chunks for the device.
"""
self.layers = layers
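
The reworked is_first_stage/is_last_stage above fold the old is_first_device/is_last_device checks into an optional model_chunk_id argument. A standalone sketch of the resulting predicate logic (plain Python mirroring the branches above, not the actual PipelineStageManager):

# Mirrors the branches in is_first_stage/is_last_stage above (illustrative only).
def is_first_stage(stage, num_stages, num_model_chunks, interleave, model_chunk_id=None):
    if not interleave or (model_chunk_id is not None and model_chunk_id < 0):
        return stage == 0                      # per-device check (old is_first_device)
    return stage == 0 and model_chunk_id == 0  # chunk 0 on device 0

def is_last_stage(stage, num_stages, num_model_chunks, interleave, model_chunk_id=None):
    if not interleave or (model_chunk_id is not None and model_chunk_id < 0):
        return stage == num_stages - 1         # per-device check (old is_last_device)
    return stage == num_stages - 1 and model_chunk_id == num_model_chunks - 1

# 2 devices x 2 chunks: only (stage=0, chunk=0) is first, only (stage=1, chunk=1) is last.
for stage in range(2):
    for chunk in range(2):
        print(stage, chunk,
              is_first_stage(stage, 2, 2, True, chunk),
              is_last_stage(stage, 2, 2, True, chunk))
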
63 changes: 29 additions & 34 deletions colossalai/shardformer/policies/base_policy.py
@@ -2,14 +2,13 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.pipeline.schedule import PipelineSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager

from ..layer.parallel_module import ParallelModule
@@ -100,12 +99,6 @@ def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
return self.shard_config.pipeline_stage_manager
return None

@property
def scheduler(self) -> Optional[PipelineSchedule]:
if self.shard_config is not None:
return self.shard_config.scheduler
return None

@abstractmethod
def config_sanity_check(self):
"""
@@ -222,29 +215,31 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:

@staticmethod
def get_stage_index(
layers_per_stage: List[int], stage: int, num_stages=None, num_model_chunks=1
) -> Union[List[int], List[List[int]]]:
# num_stages info is only needed for interleaved pipeline stage assignment
if num_stages is None:
"""
get the start index and end index of layers for each stage.
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

start_idx = num_layers_per_stage_accumulated[stage]
end_idx = num_layers_per_stage_accumulated[stage + 1]

return [start_idx, end_idx]
else:
"""
interleaved pipeline: get the start index and end index PAIRS for each stage.
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

stage_indexes = []
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indexes.append([start_idx, end_idx])

return stage_indexes
layers_per_stage: List[int],
stage: int,
num_model_chunks: int = 1,
num_stages: int = 0,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
"""
Get the start index and end index of layers for each stage.

Args:
layers_per_stage (List[int]): number of layers for each stage
stage (int): the stage index
num_stages (int): number of stages
num_model_chunks (int): number of model chunks

Returns:
- Tuple[int, int]: the start index and end index of this stage
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk

"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

stage_indices = []
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])

return stage_indices[0] if num_model_chunks == 1 else stage_indices
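
A small worked example of the indexing above (standalone numpy, reproducing the arithmetic of get_stage_index rather than importing the policy class): with 8 layers split over 2 pipeline ranks and 2 model chunks, rank 0 owns the layer ranges (0, 2) and (4, 6).

import numpy as np

# 8 layers over 2 pipeline ranks x 2 model chunks -> 4 virtual stages of 2 layers each
layers_per_stage = [2, 2, 2, 2]
stage, num_stages, num_model_chunks = 0, 2, 2

acc = np.insert(np.cumsum(layers_per_stage), 0, 0)   # [0, 2, 4, 6, 8]
indices = [
    (int(acc[stage + chunk * num_stages]), int(acc[stage + chunk * num_stages + 1]))
    for chunk in range(num_model_chunks)
]
print(indices)   # [(0, 2), (4, 6)]: rank 0 holds layers 0-1 (chunk 0) and 4-5 (chunk 1)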