.github/workflows/example_check_on_pr.yml (2 changes: 1 addition, 1 deletion)

@@ -79,7 +79,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 20
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
       cancel-in-progress: true
colossalai/booster/plugin/hybrid_parallel_plugin.py (27 changes: 23 additions, 4 deletions)

@@ -20,7 +20,7 @@
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
@@ -317,6 +317,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+        pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+        num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
     """

     def __init__(
@@ -352,6 +354,8 @@ def __init__(
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         custom_policy: Policy = None,
+        pp_style: str = "1f1b",
+        num_model_chunks: int = 1,
     ) -> None:
         super().__init__()
         assert (
@@ -376,22 +380,37 @@ def __init__(
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
+        self.pp_style = pp_style
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
-            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-            self.schedule = OneForwardOneBackwardSchedule(
-                self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+            self.stage_manager = PipelineStageManager(
+                self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=num_model_chunks
             )
+
+            if self.pp_style == "interleaved":
+                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
+                self.schedule = InterleavedSchedule(
+                    stage_manager=self.stage_manager,
+                    num_microbatches=num_microbatches,
+                    microbatch_size=microbatch_size,
+                    num_model_chunks=num_model_chunks,
+                )
+            else:
+                self.schedule = OneForwardOneBackwardSchedule(
+                    self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+                )
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             pipeline_stage_manager=self.stage_manager,
+            scheduler=self.schedule,
             enable_tensor_parallelism=self.tp_size > 1,
             enable_all_optimization=self.enable_all_optimization,
             enable_fused_normalization=self.enable_fused_normalization,
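To see how the new options fit together, here is a minimal usage sketch. It is not taken from this PR: the parallel sizes and microbatch count are illustrative, and only pp_style and num_model_chunks are the parameters added by this change.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Minimal sketch: 4 pipeline stages, 2 model chunks per device.
# num_microbatches is chosen to be divisible by both num_model_chunks
# and the number of pipeline stages, as load_batch() enforces below.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="interleaved",   # defaults to "1f1b"
    num_model_chunks=2,       # must be > 1 for the interleaved style
    num_microbatches=8,
)
booster = Booster(plugin=plugin)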
colossalai/pipeline/schedule/interleaved_pp.py (78 changes: 55 additions, 23 deletions)

@@ -3,7 +3,7 @@
 
 import torch
 import torch.cuda
-from torch.nn import Module
+from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
 
 from colossalai.interface import OptimizerWrapper
@@ -16,18 +16,25 @@
 
 
 class InterleavedSchedule(PipelineSchedule):
-    def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
-        self.num_model_chunks = num_model_chunks
-        assert (
-            num_microbatches % self.num_model_chunks == 0
-        ), "Number of microbatches should be an integer multiple of number of model chunks"
+    def __init__(
+        self,
+        stage_manager: PipelineStageManager,
+        num_microbatches: Optional[int] = None,
+        microbatch_size: Optional[int] = None,
+        num_model_chunks: Optional[int] = 1,
+    ) -> None:
         super().__init__(stage_manager)
+        assert (
+            num_microbatches is not None or microbatch_size is not None
+        ), "Either num_microbatches or microbatch_size should be provided"
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
+        self.microbatch_size = microbatch_size
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
-        self.microbatch_size: Optional[int] = None
+        self._use_microbatch_size = num_microbatches is None
+        self.num_model_chunks = num_model_chunks
 
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -42,8 +49,22 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-        assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        if not self._use_microbatch_size:
+            assert (
+                self.batch_size % self.num_microbatches == 0
+            ), "Batch size should be divisible by the number of microbatches"
+            self.microbatch_size = self.batch_size // self.num_microbatches
+        else:
+            assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
+            self.num_microbatches = self.batch_size // self.microbatch_size
+
+        assert (
+            self.num_microbatches % self.num_model_chunks == 0
+        ), "Number of microbatches should be an integer multiple of number of model chunks"
+
+        assert (
+            self.num_microbatches % self.stage_manager.num_stages == 0
+        ), "Number of microbatches should be an integer multiple of number of pipeline parallel devices"
 
     def load_micro_batch(self, model_chunk_id: int) -> Any:
         """Load a micro batch from the current batch.
@@ -83,7 +104,7 @@ def is_first_stage(self, model_chunk_id: int) -> bool:
         Returns:
             bool: Whether the current virtual stage is the first stage.
         """
-        if self.stage_manager.is_first_stage() and model_chunk_id == 0:
+        if self.stage_manager.is_first_device() and model_chunk_id == 0:
             return True
         return False
 
@@ -96,7 +117,7 @@ def is_last_stage(self, model_chunk_id: int) -> bool:
         Returns:
             bool: Whether the current virtual stage is the last stage.
         """
-        if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
+        if self.stage_manager.is_last_device() and model_chunk_id == self.num_model_chunks - 1:
             return True
         return False
 
@@ -162,7 +183,7 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None
 
     def forward_step(
         self,
-        model_chunk: Module,
+        model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         input_obj: Optional[dict],
         criterion: Callable,
@@ -171,7 +192,7 @@ def forward_step(
     ) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline
         Args:
-            model (Module): Model Chunk to be run
+            model (ModuleList or Module): Model chunk to be run
             input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
             criterion (Callable): Criterion to calculate loss.
             accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
@@ -184,7 +205,15 @@ def forward_step(
 
         # for the first stage, input_obj is None
         # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
-        output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+
+        if isinstance(model_chunk, ModuleList):
+            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+        else:
+            # in shardformer, each device still has the entire model, so we need to pass the relevant stage layers
+            if input_obj is None:
+                input_obj = {}
+            input_obj["stage_index"] = self.stage_manager.layers[model_chunk_id]
+            output_obj = model_forward(model_chunk, micro_batch, input_obj)
 
         if self.is_last_stage(model_chunk_id):
             loss = criterion(output_obj, micro_batch) / self.num_microbatches
Expand Down Expand Up @@ -243,7 +272,7 @@ def backward_step(

def forward_backward_step(
self,
model_chunk: Module,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
Expand All @@ -253,7 +282,7 @@ def forward_backward_step(
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.

Args:
model_chunk (List[Module]): Model Chunk to be trained.
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
Expand All @@ -268,7 +297,7 @@ def forward_backward_step(
assert forward_only, "Optimizer should be passed when doing backward."

self.load_batch(data_iter)
num_model_chunks = len(model_chunk)
num_model_chunks = self.num_model_chunks

# num_warmup_microbatches is the step when not all the processes are working
num_microbatches = self.num_microbatches * num_model_chunks
Expand All @@ -289,23 +318,25 @@ def forward_backward_step(
input_objs = [[] for _ in range(num_model_chunks)]
output_objs = [[] for _ in range(num_model_chunks)]

outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
outputs = [] if return_outputs and self.stage_manager.is_last_device() else None

if return_loss and self.stage_manager.is_last_stage():
if return_loss and self.stage_manager.is_last_device():
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None

# for ranks except the first one, get into recv state
# print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
input_obj = self.recv_forward(0)
input_objs[0].append(input_obj)
if not forward_only:
input_objs[0].append(input_obj)

# Run warmup forward passes.
for i in range(num_warmup_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=True)

self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
# recv first on first rank to avoid sending or recving at the same time
if self.stage_manager.is_first_stage():
if self.stage_manager.is_first_device():
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
Expand All @@ -329,6 +360,7 @@ def forward_backward_step(
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
last_iteration = i == (num_microbatches_remaining - 1)

output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
Expand Down Expand Up @@ -367,10 +399,10 @@ def forward_backward_step(
if not forward_only:
for i in range(num_microbatches_remaining, num_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=False)
self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
# print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)

output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
Expand Down
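The schedule above relies on get_model_chunk_id, which is not part of this diff. As a point of reference only, the usual interleaved-1F1B mapping (as in Megatron-LM; ColossalAI's helper may differ in detail) assigns microbatches to chunks in groups of num_stages and reverses the order for the backward pass:

# Hedged sketch of a typical get_model_chunk_id; not the code from this PR.
def get_model_chunk_id(microbatch_id: int, num_stages: int, num_model_chunks: int, forward: bool) -> int:
    microbatch_id_in_group = microbatch_id % (num_stages * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // num_stages
    if not forward:
        # backward passes drain the chunks in reverse order
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

# Example with num_stages=4, num_model_chunks=2: forward microbatches 0-3
# run on chunk 0, microbatches 4-7 on chunk 1, then the pattern repeats.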
colossalai/pipeline/stage_manager.py (41 changes: 39 additions, 2 deletions)

@@ -18,7 +18,9 @@ class PipelineStageManager:
         stage (int): The current stage.
     """
 
-    def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
+    def __init__(
+        self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False, num_model_chunks=1
+    ) -> None:
         self.pg_mesh = pg_mesh
         self.pipeline_axis = pipeline_axis
         self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -32,6 +34,10 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo
         # the next rank of the last rank is rank0
         next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
         self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
+        # for interleaved pipeline parallelism, each device is responsible for multiple chunks of layers
+        self.num_model_chunks = num_model_chunks
+        self.model_chunk_id = 0
+        self.layers: Optional[List[List[int]]] = None
 
         # init p2p process groups
         stages = list(range(self.num_stages))
@@ -55,14 +61,31 @@ def is_first_stage(self) -> bool:
         Returns:
             bool: Whether the current stage is the first stage.
         """
-        return self.stage == 0
+        return self.stage == 0 and self.model_chunk_id == 0
 
     def is_last_stage(self) -> bool:
         """Is the current stage the last stage.
 
         Returns:
             bool: Whether the current stage is the last stage.
         """
+        return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
+
+    # introduced for interleaved pipeline parallelism, as the first/last device may also hold intermediate stages
+    def is_first_device(self) -> bool:
+        """Is the current stage on the first device.
+
+        Returns:
+            bool: Whether the current stage is on the first device.
+        """
+        return self.stage == 0
+
+    def is_last_device(self) -> bool:
+        """Is the current stage on the last device.
+
+        Returns:
+            bool: Whether the current stage is on the last device.
+        """
         return self.stage == self.num_stages - 1
 
     @property
@@ -131,3 +154,17 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
             ProcessGroup: Process group of the given stages.
         """
         return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
+
+    def set_interleaved_model_chunk_id(self, model_chunk_id: int):
+        """For interleaved pipeline parallelism, set the model chunk id for the device at the current stage.
+        Args:
+            model_chunk_id (int): the id of the current model chunk for the device.
+        """
+        self.model_chunk_id = model_chunk_id
+
+    def set_interleaved_device_layers(self, layers: List[List[int]]):
+        """For interleaved pipeline parallelism, set the layer chunks for the device.
+        Args:
+            layers (List[List[int]]): list of layer chunks for the device.
+        """
+        self.layers = layers
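To make the new first/last distinction concrete, here is a small sketch of how the four predicates behave, under an assumed layout of 4 devices with 2 chunks per device (manager is a hypothetical PipelineStageManager living on device 0):

# Assumed: manager.num_stages == 4, manager.num_model_chunks == 2, manager.stage == 0.
manager.set_interleaved_model_chunk_id(0)
manager.is_first_device()   # True: this is device 0, whatever the chunk
manager.is_first_stage()    # True: device 0 and model_chunk_id == 0

manager.set_interleaved_model_chunk_id(1)
manager.is_first_device()   # still True
manager.is_first_stage()    # False: device 0 now runs an intermediate virtual stage
manager.is_last_stage()     # False here; True only on device 3 with chunk id 1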