From dde39a4fb33456739f8fc4ba8cb05cdd8f1b335a Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Tue, 26 Sep 2023 15:01:52 +0800 Subject: [PATCH 1/7] [shardformer] support interleaved pipeline parallel for bert finetune (hard coded) --- .../booster/plugin/hybrid_parallel_plugin.py | 21 ++++++-- .../pipeline/schedule/interleaved_pp.py | 16 +++++-- colossalai/shardformer/modeling/bert.py | 29 +++++++---- colossalai/shardformer/policies/bert.py | 48 ++++++++++++++++++- examples/language/bert/finetune.py | 10 ++-- 5 files changed, 102 insertions(+), 22 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 479ccc3eb36e..63a41e0cdeaf 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -20,7 +20,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy @@ -376,16 +376,29 @@ def __init__( self.stage_manager = None self.schedule = None self.custom_policy = custom_policy + self.num_microbatches = num_microbatches + print("num_microbatches: ", self.num_microbatches) + print("micro_batch_size: ", microbatch_size) assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager(self.pg_mesh, 
PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule( - self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS, is_virtual=True) + """ + self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) + """ + self.schedule = InterleavedSchedule( + num_microbatches=num_microbatches, + # microbatch_size=microbatch_size, + num_model_chunks=2, + stage_manager=self.stage_manager, ) + #''' + # raise Exception self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 780437155c61..53c977bf030e 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -163,6 +163,7 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None def forward_step( self, model_chunk: Module, + # layers: Optional[List[List[int]]], model_chunk_id: int, input_obj: Optional[dict], criterion: Callable, @@ -184,7 +185,11 @@ def forward_step( # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict - output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + if input_obj is None: + input_obj = {} + input_obj["model_chunk_id"] = model_chunk_id + # output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + output_obj = model_forward(model_chunk, micro_batch, input_obj) if self.is_last_stage(model_chunk_id): loss = criterion(output_obj, micro_batch) / self.num_microbatches @@ -243,13 +248,15 @@ def backward_step( def forward_backward_step( self, - model_chunk: Module, + model_chunk, data_iter: 
Iterable, criterion: Callable[..., Any], optimizer: Optional[OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False, ) -> dict: + # self.xxx = func(model.config.xx, stage + """Runs interleaved 1F1B schedule, with communication between pipeline stages. Args: @@ -268,7 +275,8 @@ def forward_backward_step( assert forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) - num_model_chunks = len(model_chunk) + # raise Exception + num_model_chunks = 2 # num_warmup_microbatches is the step when not all the processes are working num_microbatches = self.num_microbatches * num_model_chunks @@ -303,10 +311,10 @@ def forward_backward_step( # Run warmup forward passes. for i in range(num_warmup_microbatches): model_chunk_id = self.get_model_chunk_id(i, forward=True) - # recv first on first rank to avoid sending or recving at the same time if self.stage_manager.is_first_stage(): input_obj = self.recv_forward(model_chunk_id) + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) self.send_forward(model_chunk_id, output_obj) if not forward_only: diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 7411e1d0ec46..e3824db7788c 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -57,8 +57,10 @@ def bert_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage - stage_index: Optional[List[int]] = None, + # stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + layers: Optional[List[List[int]]] = None, + model_chunk_id: Optional[int] = None, ): # TODO(jianghai): add explaination of the output here. 
r""" @@ -94,7 +96,9 @@ def bert_model_forward( else: use_cache = False - if stage_manager.is_first_stage(): + # get stage index based on assigned layers and chunk id + stage_index = layers[model_chunk_id] + if stage_manager.is_first_stage() and model_chunk_id == 0: if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -157,7 +161,7 @@ def bert_model_forward( head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) hidden_states = hidden_states if hidden_states is not None else None - if stage_manager.is_first_stage(): + if stage_manager.is_first_stage() and model_chunk_id == 0: hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, @@ -179,7 +183,10 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None + # retrieval stage from multiple stages + # stage_index = multiple_stage_index[model_chunk_id] start_idx, end_idx = stage_index[0], stage_index[1] + # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None @@ -195,7 +202,7 @@ def bert_model_forward( ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): - if stage_manager.is_first_stage() and idx == 0: + if stage_manager.is_first_stage() and model_chunk_id == 0 and idx == 0: encoder_attention_mask = encoder_extended_attention_mask if output_hidden_states: @@ -250,7 +257,7 @@ def custom_forward(*inputs): # end of a stage loop sequence_output = hidden_states if hidden_states is not None else None - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage() and model_chunk_id == 1: pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] @@ -660,8 +667,12 @@ def bert_for_sequence_classification_forward( return_dict: Optional[bool] = None, 
hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, - stage_index: Optional[List[int]] = None, + # stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + model_chunk_id: Optional[int] = None, + layers: Optional[List[List[int]]] = None, + # num_chunks: int = None, + # model_chunk_id: int = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -693,11 +704,13 @@ def bert_for_sequence_classification_forward( return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index, + # stage_index=stage_index, shard_config=shard_config, + layers=layers, + model_chunk_id=model_chunk_id, ) - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage() and model_chunk_id == 1: pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 14146de158ae..d4b2b00b1a07 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -250,14 +250,44 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli module = self.model else: module = self.model.bert + #''' + # interleaved + num_chunks = 2 + layers_per_stage = Policy.distribute_layers( + len(module.encoder.layer), stage_manager.num_stages * num_chunks + ) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + # [0,3],[6,9] + + multiple_stage_index = [] + multiple_stage_index.append(stage_index) + if stage_index[0] == 0: + multiple_stage_index.append([6, 9]) + else: + multiple_stage_index.append([9, 12]) + + print("multiple stages added") + print(multiple_stage_index) + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + layers=multiple_stage_index, + shard_config=self.shard_config, + ) + } + """ + # 1f1b layers_per_stage = 
Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, stage_manager=stage_manager, layers=stage_index, shard_config=self.shard_config ) } + """ self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls ) @@ -275,11 +305,27 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] + #''' + # interleaved + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages * 2) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + # raise Exception + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + held_layers.extend(module.encoder.layer[start_idx + 6 : end_idx + 6]) + + """ + + #1f1b layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + #raise Exception held_layers.extend(module.encoder.layer[start_idx:end_idx]) + """ + if stage_manager.is_last_stage(): held_layers.append(module.pooler) diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 563cfa58d5f6..781368409088 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -221,12 +221,12 @@ def main(): plugin = HybridParallelPlugin( tp_size=1, pp_size=2, - num_microbatches=None, + num_microbatches=2, microbatch_size=1, - enable_all_optimization=True, - zero_stage=1, - precision="fp16", - initial_scale=1, + # enable_all_optimization=False, + # 
zero_stage=1, + # precision="fp16", + # initial_scale=1, ) booster = Booster(plugin=plugin, **booster_kwargs) From eb496e0f22fc3e7269df70ef0aba7aadf7d936c4 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Tue, 26 Sep 2023 18:05:40 +0800 Subject: [PATCH 2/7] [shardformer] generalize interleave solution for non bert model 26 Sep --- .../booster/plugin/hybrid_parallel_plugin.py | 34 ++++++++++------- .../pipeline/schedule/interleaved_pp.py | 23 ++++++----- colossalai/pipeline/stage_manager.py | 6 ++- .../shardformer/policies/base_policy.py | 38 ++++++++++++++----- colossalai/shardformer/policies/bert.py | 9 +++-- examples/language/bert/finetune.py | 4 +- 6 files changed, 73 insertions(+), 41 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 63a41e0cdeaf..acf06e52742c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -20,7 +20,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.pipeline.schedule import InterleavedSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy @@ -352,6 +352,8 @@ def __init__( communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, ) -> None: super().__init__() assert ( @@ -377,26 +379,30 @@ def __init__( self.schedule = None self.custom_policy = custom_policy self.num_microbatches = num_microbatches - print("num_microbatches: ", self.num_microbatches) - print("micro_batch_size: ", 
microbatch_size) + self.pp_style = pp_style + self.num_model_chunks = num_model_chunks assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS, is_virtual=True) - """ - self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size) - """ - self.schedule = InterleavedSchedule( - num_microbatches=num_microbatches, - # microbatch_size=microbatch_size, - num_model_chunks=2, - stage_manager=self.stage_manager, + self.stage_manager = PipelineStageManager( + self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=self.num_model_chunks ) + + if self.pp_style == "interleaved": + assert self.num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + num_microbatches=num_microbatches, + num_model_chunks=num_model_chunks, + stage_manager=self.stage_manager, + ) + else: + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) + #''' # raise Exception self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 53c977bf030e..e2daa60e06ea 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -3,7 +3,7 @@ import torch import torch.cuda -from torch.nn import Module +from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper @@ -162,8 +162,7 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: 
int = None def forward_step( self, - model_chunk: Module, - # layers: Optional[List[List[int]]], + model_chunk: Union[ModuleList, Module], model_chunk_id: int, input_obj: Optional[dict], criterion: Callable, @@ -172,7 +171,7 @@ def forward_step( ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: - model (Module): Model Chunk to be run + model (ModuleList or Module): Model Chunk to be run input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. criterion (Callable): Criterion to calculate loss. accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. @@ -188,8 +187,11 @@ def forward_step( if input_obj is None: input_obj = {} input_obj["model_chunk_id"] = model_chunk_id - # output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) - output_obj = model_forward(model_chunk, micro_batch, input_obj) + + if isinstance(model_chunk, ModuleList): + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + else: + output_obj = model_forward(model_chunk, micro_batch, input_obj) if self.is_last_stage(model_chunk_id): loss = criterion(output_obj, micro_batch) / self.num_microbatches @@ -248,19 +250,17 @@ def backward_step( def forward_backward_step( self, - model_chunk, + model_chunk: Union[ModuleList, Module], data_iter: Iterable, criterion: Callable[..., Any], optimizer: Optional[OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False, ) -> dict: - # self.xxx = func(model.config.xx, stage - """Runs interleaved 1F1B schedule, with communication between pipeline stages. Args: - model_chunk (List[Module]): Model Chunk to be trained. + model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification data_iter (Iterable): Data iterator. criterion (Callable[[Any, Any], Tensor]): Criterion to be used. 
It should take two arguments: model outputs and inputs, and returns loss tensor. optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. @@ -275,8 +275,7 @@ def forward_backward_step( assert forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) - # raise Exception - num_model_chunks = 2 + num_model_chunks = self.num_model_chunks # num_warmup_microbatches is the step when not all the processes are working num_microbatches = self.num_microbatches * num_model_chunks diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index b79867a2c651..4d3c6a0ebcf7 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -18,7 +18,9 @@ class PipelineStageManager: stage (int): The current stage. """ - def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: + def __init__( + self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False, num_model_chunks=1 + ) -> None: self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None @@ -32,6 +34,8 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo # the next rank of the last rank is rank0 next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") + # number of layer chunks in each stage for interleaved pipeline, with each device has non-discontinuous layers + self.num_model_chunks = num_model_chunks # init p2p process groups stages = list(range(self.num_stages)) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e7f199129a00..1a8f7058f5d7 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ 
b/colossalai/shardformer/policies/base_policy.py @@ -214,13 +214,31 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: return layers_per_stage @staticmethod - def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: - """ - get the start index and end index of layers for each stage. - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - start_idx = num_layers_per_stage_accumulated[stage] - end_idx = num_layers_per_stage_accumulated[stage + 1] - - return [start_idx, end_idx] + def get_stage_index( + layers_per_stage: List[int], stage: int, num_stages=None, num_model_chunks=1 + ) -> Union[List[int], List[List[int]]]: + # [6, 6] stage 0 + # [3,3,3,3] stage 0 + if num_stages is None: + """ + get the start index and end index of layers for each stage. + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[stage] + end_idx = num_layers_per_stage_accumulated[stage + 1] + + return [start_idx, end_idx] + else: + """ + interleaved pipeline: get the start index and end index pairs for each stage's layers. 
+ """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + stage_indexes = [] + for model_chunk in range(num_model_chunks): + start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] + end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] + stage_indexes.append([start_idx, end_idx]) + + return stage_indexes diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index d4b2b00b1a07..777b56fbb8b2 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -250,12 +250,15 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli module = self.model else: module = self.model.bert - #''' - # interleaved - num_chunks = 2 + + # num_chunks > 1 if interleaved + num_chunks = stage_manager.num_model_chunks layers_per_stage = Policy.distribute_layers( len(module.encoder.layer), stage_manager.num_stages * num_chunks ) + print("***layers per stage***") + print(layers_per_stage) + raise Exception("assigning layers") stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) # [0,3],[6,9] diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 781368409088..3e123c6502a3 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -222,7 +222,9 @@ def main(): tp_size=1, pp_size=2, num_microbatches=2, - microbatch_size=1, + pp_style="interleaved", + num_model_chunks=2, + # microbatch_size=1, # enable_all_optimization=False, # zero_stage=1, # precision="fp16", From f631c3e01c118213526fac854d2bb2b61f87a0ce Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Wed, 27 Sep 2023 13:59:36 +0800 Subject: [PATCH 3/7] [shardformer] interleaved pipeline parallel for bert fine tune example --- .../booster/plugin/hybrid_parallel_plugin.py | 3 +- .../pipeline/schedule/interleaved_pp.py | 43 ++++++++--- 
colossalai/pipeline/stage_manager.py | 2 +- colossalai/shardformer/modeling/bert.py | 26 +++---- .../shardformer/policies/base_policy.py | 5 +- colossalai/shardformer/policies/bert.py | 74 +++++++------------ examples/language/bert/finetune.py | 16 ++-- .../test_schedule/test_interleaved.py | 11 ++- 8 files changed, 95 insertions(+), 85 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index acf06e52742c..6ca12801307e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -394,9 +394,10 @@ def __init__( if self.pp_style == "interleaved": assert self.num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, num_microbatches=num_microbatches, + microbatch_size=microbatch_size, num_model_chunks=num_model_chunks, - stage_manager=self.stage_manager, ) else: self.schedule = OneForwardOneBackwardSchedule( diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index e2daa60e06ea..1fc4a2e3422f 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -16,18 +16,25 @@ class InterleavedSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: - self.num_model_chunks = num_model_chunks - assert ( - num_microbatches % self.num_model_chunks == 0 - ), "Number of microbatches should be an integer multiple of number of model chunks" + def __init__( + self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + num_model_chunks: Optional[int] = 1, + ) -> None: super().__init__(stage_manager) + assert ( + num_microbatches is not None or microbatch_size is not None + ), "Either 
num_microbatches or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches + self.microbatch_size = microbatch_size self.batch: Optional[Any] = None self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None - self.microbatch_size: Optional[int] = None + self._use_microbatch_size = num_microbatches is None + self.num_model_chunks = num_model_chunks def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -42,8 +49,22 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] - assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches + if not self._use_microbatch_size: + assert ( + self.batch_size % self.num_microbatches == 0 + ), "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + else: + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" + self.num_microbatches = self.batch_size // self.microbatch_size + + assert ( + self.num_microbatches % self.num_model_chunks == 0 + ), "Number of microbatches should be an integer multiple of number of model chunks" + + assert ( + self.num_microbatches % self.stage_manager.num_stages == 0 + ), "Number of microbatches should be an integer multiple of number of pipeline parallel devices" def load_micro_batch(self, model_chunk_id: int) -> Any: """Load a micro batch from the current batch. 
@@ -184,6 +205,8 @@ def forward_step( # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + + # in shardformer, each device still has the entire model, so we need to pass the model_chunk_id to replaced forward if input_obj is None: input_obj = {} input_obj["model_chunk_id"] = model_chunk_id @@ -306,7 +329,9 @@ def forward_backward_step( # for ranks except the first one, get into recv state # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) input_obj = self.recv_forward(0) - input_objs[0].append(input_obj) + if not forward_only: + input_objs[0].append(input_obj) + # Run warmup forward passes. for i in range(num_warmup_microbatches): model_chunk_id = self.get_model_chunk_id(i, forward=True) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 4d3c6a0ebcf7..18740c13312c 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -34,7 +34,7 @@ def __init__( # the next rank of the last rank is rank0 next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") - # number of layer chunks in each stage for interleaved pipeline, with each device has non-discontinuous layers + # number of layer chunks in each stage for interleaved pipeline, with each device has discontinuous layers self.num_model_chunks = num_model_chunks # init p2p process groups diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index e3824db7788c..1cc9062924ef 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -57,10 +57,9 @@ def bert_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: 
Optional[torch.FloatTensor] = None, # this is from the previous stage - # stage_index: Optional[List[int]] = None, + stage_index: Union[Optional[List[int]], Optional[List[List[int]]]] = None, shard_config: ShardConfig = None, - layers: Optional[List[List[int]]] = None, - model_chunk_id: Optional[int] = None, + model_chunk_id: int = 0, ): # TODO(jianghai): add explaination of the output here. r""" @@ -96,8 +95,9 @@ def bert_model_forward( else: use_cache = False - # get stage index based on assigned layers and chunk id - stage_index = layers[model_chunk_id] + # if interleaved, get stage index from a list of stages based on chunk id + if stage_index is not None and all(isinstance(item, list) for item in stage_index): + stage_index = stage_index[model_chunk_id] if stage_manager.is_first_stage() and model_chunk_id == 0: if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -183,8 +183,6 @@ def bert_model_forward( use_cache = False next_decoder_cache = () if use_cache else None - # retrieval stage from multiple stages - # stage_index = multiple_stage_index[model_chunk_id] start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs @@ -257,7 +255,7 @@ def custom_forward(*inputs): # end of a stage loop sequence_output = hidden_states if hidden_states is not None else None - if stage_manager.is_last_stage() and model_chunk_id == 1: + if stage_manager.is_last_stage() and model_chunk_id == stage_manager.num_model_chunks - 1: pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] @@ -667,12 +665,9 @@ def bert_for_sequence_classification_forward( return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, - # stage_index: Optional[List[int]] = None, + stage_index: Union[Optional[List[int]], 
Optional[List[List[int]]]] = None, shard_config: ShardConfig = None, - model_chunk_id: Optional[int] = None, - layers: Optional[List[List[int]]] = None, - # num_chunks: int = None, - # model_chunk_id: int = None, + model_chunk_id: int = 0, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -704,13 +699,12 @@ def bert_for_sequence_classification_forward( return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - # stage_index=stage_index, + stage_index=stage_index, shard_config=shard_config, - layers=layers, model_chunk_id=model_chunk_id, ) - if stage_manager.is_last_stage() and model_chunk_id == 1: + if stage_manager.is_last_stage() and model_chunk_id == stage_manager.num_model_chunks - 1: pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 1a8f7058f5d7..5c7c89b77c23 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -217,8 +217,7 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: def get_stage_index( layers_per_stage: List[int], stage: int, num_stages=None, num_model_chunks=1 ) -> Union[List[int], List[List[int]]]: - # [6, 6] stage 0 - # [3,3,3,3] stage 0 + # num_stages info is only needed for interleaved pipeline stage assignment if num_stages is None: """ get the start index and end index of layers for each stage. @@ -231,7 +230,7 @@ def get_stage_index( return [start_idx, end_idx] else: """ - interleaved pipeline: get the start index and end index pairs for each stage's layers. + interleaved pipeline: get the start index and end index PAIRS for each stage. 
""" num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 777b56fbb8b2..fcc0e0c5bd98 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -251,46 +251,25 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.bert - # num_chunks > 1 if interleaved - num_chunks = stage_manager.num_model_chunks + # num_model_chunks > 1 if interleaved + num_model_chunks = stage_manager.num_model_chunks layers_per_stage = Policy.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * num_chunks + len(module.encoder.layer), stage_manager.num_stages * num_model_chunks ) - print("***layers per stage***") - print(layers_per_stage) - raise Exception("assigning layers") - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - # [0,3],[6,9] - multiple_stage_index = [] - multiple_stage_index.append(stage_index) - if stage_index[0] == 0: - multiple_stage_index.append([6, 9]) + if num_model_chunks > 1: + stage_index = Policy.get_stage_index( + layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks + ) else: - multiple_stage_index.append([9, 12]) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - print("multiple stages added") - print(multiple_stage_index) method_replacement = { "forward": partial( - new_forward, - stage_manager=stage_manager, - layers=multiple_stage_index, - shard_config=self.shard_config, + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config ) } - """ - # 1f1b - layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - - method_replacement = { - "forward": partial( - new_forward, 
stage_manager=stage_manager, layers=stage_index, shard_config=self.shard_config - ) - } - """ self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls ) @@ -308,26 +287,27 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - #''' - # interleaved - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages * 2) + num_model_chunks = stage_manager.num_model_chunks + layers_per_stage = self.distribute_layers( + len(module.encoder.layer), stage_manager.num_stages * num_model_chunks + ) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - # raise Exception - held_layers.extend(module.encoder.layer[start_idx:end_idx]) - held_layers.extend(module.encoder.layer[start_idx + 6 : end_idx + 6]) - - """ + if num_model_chunks > 1: + stage_index = Policy.get_stage_index( + layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks + ) + else: + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - #1f1b - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - #raise Exception - held_layers.extend(module.encoder.layer[start_idx:end_idx]) - """ + # interleaved stage index for one device comes in pairs, e.g.[[0,3],[6,9]] + if all(isinstance(item, list) for item in stage_index): + for i in range(len(stage_index)): + start_idx, end_idx = stage_index[i] + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + else: + start_idx, end_idx = stage_index + held_layers.extend(module.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.pooler) diff --git 
a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 3e123c6502a3..416826d9169b 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -64,11 +64,17 @@ def evaluate_subset(dataloader: DataLoader): batch = move_to_cuda(batch) labels = batch["labels"] if use_pipeline: + """skip the last batch with batch size 31 for interleaved pipeline parallel + as the number of microbatches needs to be a multiple of pipeline parallel devices + """ + if booster.plugin.pp_style == "interleaved" and len(labels) < 32: + continue pg_mesh = booster.plugin.pg_mesh pp_group = booster.plugin.pp_group current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) current_rank = dist.get_rank() batch = iter([batch]) + outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True) if is_pp_last_stage: @@ -224,11 +230,11 @@ def main(): num_microbatches=2, pp_style="interleaved", num_model_chunks=2, - # microbatch_size=1, - # enable_all_optimization=False, - # zero_stage=1, - # precision="fp16", - # initial_scale=1, + microbatch_size=None, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, ) booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index f181453eaed5..567641c4df3f 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -66,7 +66,7 @@ def examine_pp(num_micro_batches): seed_all(1453) NUM_MICRO_BATCHS = num_micro_batches - BATCH_SIZE = num_micro_batches + BATCH_SIZE = 24 NUM_CHUNKS = 2 # create model @@ -76,8 +76,13 @@ def examine_pp(num_micro_batches): DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 pg_mesh = ProcessGroupMesh(1, world_size, 1) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True) - schedule = InterleavedSchedule(NUM_MICRO_BATCHS, 
NUM_CHUNKS, stage_manager) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True, num_model_chunks=NUM_CHUNKS) + schedule = InterleavedSchedule( + stage_manager=stage_manager, + num_microbatches=NUM_MICRO_BATCHS, + microbatch_size=None, + num_model_chunks=NUM_CHUNKS, + ) sharded_model = torch.nn.ModuleList() for idx, (_, sub_model) in enumerate(pp_model.named_children()): From ad33868e8b8d9359b1b7eca20da27583fa2d6f32 Mon Sep 17 00:00:00 2001 From: Pengtai Xu Date: Fri, 29 Sep 2023 17:37:28 +0800 Subject: [PATCH 4/7] [shardformer] refactor interleave implementation so that replaced forward fn in shardformer remain unchanged for all models --- .../booster/plugin/hybrid_parallel_plugin.py | 7 ++--- .../pipeline/schedule/interleaved_pp.py | 25 ++++++++--------- colossalai/pipeline/stage_manager.py | 20 +++++++++++++- colossalai/shardformer/modeling/bert.py | 21 +++++---------- .../shardformer/policies/base_policy.py | 7 +++++ colossalai/shardformer/policies/bert.py | 27 ++++++++++++------- colossalai/shardformer/shard/shard_config.py | 3 +++ examples/language/bert/finetune.py | 15 +++++------ .../test_schedule/test_interleaved.py | 2 +- 9 files changed, 79 insertions(+), 48 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 6ca12801307e..b5374a0b607e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -317,6 +317,8 @@ class HybridParallelPlugin(PipelinePluginBase): communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. + pp_style (str, optional): The style for pipeline parallelism. 
Defaults to '1f1b'. + num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. """ def __init__( @@ -403,15 +405,14 @@ def __init__( self.schedule = OneForwardOneBackwardSchedule( self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size ) - - #''' - # raise Exception self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, + scheduler=self.schedule, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 1fc4a2e3422f..2d63b4f11bb1 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -35,6 +35,7 @@ def __init__( self.microbatch_offset: Optional[int] = None self._use_microbatch_size = num_microbatches is None self.num_model_chunks = num_model_chunks + self.layers = None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -104,7 +105,7 @@ def is_first_stage(self, model_chunk_id: int) -> bool: Returns: bool: Whether the current virtual stage is the first stage. """ - if self.stage_manager.is_first_stage() and model_chunk_id == 0: + if self.stage_manager.is_first_device() and model_chunk_id == 0: return True return False @@ -117,7 +118,7 @@ def is_last_stage(self, model_chunk_id: int) -> bool: Returns: bool: Whether the current virtual stage is the last stage. 
""" - if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1: + if self.stage_manager.is_last_device() and model_chunk_id == self.num_model_chunks - 1: return True return False @@ -206,14 +207,13 @@ def forward_step( # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict - # in shardformer, each device still has the entire model, so we need to pass the model_chunk_id to replaced forward - if input_obj is None: - input_obj = {} - input_obj["model_chunk_id"] = model_chunk_id - if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) else: + # in shardformer, each device still has the entire model, so we need to pass the relevant stage layers + if input_obj is None: + input_obj = {} + input_obj["stage_index"] = self.layers[model_chunk_id] output_obj = model_forward(model_chunk, micro_batch, input_obj) if self.is_last_stage(model_chunk_id): @@ -319,9 +319,9 @@ def forward_backward_step( input_objs = [[] for _ in range(num_model_chunks)] output_objs = [[] for _ in range(num_model_chunks)] - outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + outputs = [] if return_outputs and self.stage_manager.is_last_device() else None - if return_loss and self.stage_manager.is_last_stage(): + if return_loss and self.stage_manager.is_last_device(): accum_loss = torch.zeros(1, device=get_current_device()) else: accum_loss = None @@ -335,10 +335,10 @@ def forward_backward_step( # Run warmup forward passes. 
for i in range(num_warmup_microbatches): model_chunk_id = self.get_model_chunk_id(i, forward=True) + self.stage_manager.model_chunk_id = model_chunk_id # recv first on first rank to avoid sending or recving at the same time - if self.stage_manager.is_first_stage(): + if self.stage_manager.is_first_device(): input_obj = self.recv_forward(model_chunk_id) - output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) self.send_forward(model_chunk_id, output_obj) if not forward_only: @@ -361,6 +361,7 @@ def forward_backward_step( # Run 1F1B in steady state. for i in range(num_microbatches_remaining): model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) + self.stage_manager.model_chunk_id = model_chunk_id last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) @@ -399,10 +400,10 @@ def forward_backward_step( if not forward_only: for i in range(num_microbatches_remaining, num_microbatches): model_chunk_id = self.get_model_chunk_id(i, forward=False) + self.stage_manager.model_chunk_id = model_chunk_id # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") input_obj = input_objs[model_chunk_id].pop(0) output_obj = output_objs[model_chunk_id].pop(0) - output_obj_grad = self.recv_backward(model_chunk_id) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) self.send_backward(model_chunk_id, input_obj_grad) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 18740c13312c..8a623e853658 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -36,6 +36,7 @@ def __init__( self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") # number of layer chunks in each stage for interleaved pipeline, with 
each device has discontinuous layers self.num_model_chunks = num_model_chunks + self.model_chunk_id = 0 # init p2p process groups stages = list(range(self.num_stages)) @@ -59,7 +60,7 @@ def is_first_stage(self) -> bool: Returns: bool: Whether the current stage is the first stage. """ - return self.stage == 0 + return self.stage == 0 and self.model_chunk_id == 0 def is_last_stage(self) -> bool: """Is the current stage the last stage. @@ -67,6 +68,23 @@ def is_last_stage(self) -> bool: Returns: bool: Whether the current stage is the last stage. """ + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + + # introduced due to interleaved pipeline parallel, as the first/last device may also hold intermediate stages + def is_first_device(self) -> bool: + """Is the current stage on the first device. + + Returns: + bool: Whether the current stage is on the first device. + """ + return self.stage == 0 + + def is_last_device(self) -> bool: + """Is the current stage on the last device. + + Returns: + bool: Whether the current stage is on the last device. + """ return self.stage == self.num_stages - 1 @property diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 1cc9062924ef..7411e1d0ec46 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -57,9 +57,8 @@ def bert_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage - stage_index: Union[Optional[List[int]], Optional[List[List[int]]]] = None, + stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, - model_chunk_id: int = 0, ): # TODO(jianghai): add explaination of the output here.
r""" @@ -95,10 +94,7 @@ def bert_model_forward( else: use_cache = False - # if interleaved, get stage index from a list of stages based on chunk id - if stage_index is not None and all(isinstance(item, list) for item in stage_index): - stage_index = stage_index[model_chunk_id] - if stage_manager.is_first_stage() and model_chunk_id == 0: + if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -161,7 +157,7 @@ def bert_model_forward( head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) hidden_states = hidden_states if hidden_states is not None else None - if stage_manager.is_first_stage() and model_chunk_id == 0: + if stage_manager.is_first_stage(): hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, @@ -184,7 +180,6 @@ def bert_model_forward( next_decoder_cache = () if use_cache else None start_idx, end_idx = stage_index[0], stage_index[1] - # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None @@ -200,7 +195,7 @@ def bert_model_forward( ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): - if stage_manager.is_first_stage() and model_chunk_id == 0 and idx == 0: + if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask if output_hidden_states: @@ -255,7 +250,7 @@ def custom_forward(*inputs): # end of a stage loop sequence_output = hidden_states if hidden_states is not None else None - if stage_manager.is_last_stage() and model_chunk_id == stage_manager.num_model_chunks - 1: + if stage_manager.is_last_stage(): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + layer_outputs[1:] @@ -665,9 +660,8 @@ def bert_for_sequence_classification_forward( 
return_dict: Optional[bool] = None, hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, - stage_index: Union[Optional[List[int]], Optional[List[List[int]]]] = None, + stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, - model_chunk_id: int = 0, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -701,10 +695,9 @@ def bert_for_sequence_classification_forward( stage_manager=stage_manager, stage_index=stage_index, shard_config=shard_config, - model_chunk_id=model_chunk_id, ) - if stage_manager.is_last_stage() and model_chunk_id == stage_manager.num_model_chunks - 1: + if stage_manager.is_last_stage(): pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 5c7c89b77c23..789565e5a010 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -9,6 +9,7 @@ from torch import Tensor from torch.nn import Module +from colossalai.pipeline.schedule import PipelineSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from ..layer.parallel_module import ParallelModule @@ -99,6 +100,12 @@ def pipeline_stage_manager(self) -> Optional[PipelineStageManager]: return self.shard_config.pipeline_stage_manager return None + @property + def scheduler(self) -> Optional[PipelineSchedule]: + if self.shard_config is not None: + return self.shard_config.scheduler + return None + @abstractmethod def config_sanity_check(self): """ diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fcc0e0c5bd98..689b0b659ff1 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -264,12 +264,21 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: stage_index = 
Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config - ) - } - + if num_model_chunks == 1: + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, + ) + } + # for interleaved, stage index for each forward is chosen in scheduler + else: + self.scheduler.layers = stage_index + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls ) @@ -291,7 +300,7 @@ def get_held_layers(self) -> List[Module]: layers_per_stage = self.distribute_layers( len(module.encoder.layer), stage_manager.num_stages * num_model_chunks ) - if stage_manager.is_first_stage(): + if stage_manager.is_first_device(): held_layers.append(module.embeddings) if num_model_chunks > 1: stage_index = Policy.get_stage_index( @@ -309,7 +318,7 @@ def get_held_layers(self) -> List[Module]: start_idx, end_idx = stage_index held_layers.extend(module.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(): + if stage_manager.is_last_device(): held_layers.append(module.pooler) return held_layers @@ -501,7 +510,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_device(): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a285874d218b..ca9a17784013 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -4,6 +4,7 @@ import 
torch.distributed as dist from torch.distributed import ProcessGroup +from colossalai.pipeline.schedule import PipelineSchedule from colossalai.pipeline.stage_manager import PipelineStageManager __all__ = ["ShardConfig"] @@ -17,6 +18,7 @@ class ShardConfig: Args: tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group. pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. + scheduler (Optional[PipelineSchedule]): If using interleaved pp, it's necessary to specify the scheduler for layer assignment for each device. enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True. enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False. enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. 
@@ -28,6 +30,7 @@ class ShardConfig: """ tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None + scheduler: Optional[PipelineSchedule] = None enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_flash_attention: bool = False diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 416826d9169b..720a4e6699d9 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -57,7 +57,7 @@ def evaluate_model( def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device() accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: @@ -77,7 +77,7 @@ def evaluate_subset(dataloader: DataLoader): outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True) - if is_pp_last_stage: + if is_pp_last_device: logits = outputs["outputs"]["logits"] val_loss = outputs["loss"] accum_loss.add_(val_loss) @@ -93,7 +93,6 @@ def evaluate_subset(dataloader: DataLoader): elif current_rank in current_pp_group_ranks: object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) accum_loss.add_(object_list[1].to(get_current_device())) @@ -139,8 +138,8 @@ def train_epoch( coordinator: DistCoordinator, ): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + 
is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device) total_step = len(train_dataloader) model.train() @@ -154,7 +153,7 @@ def train_epoch( train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True ) # Backward and optimize - if is_pp_last_stage: + if is_pp_last_device: loss = outputs["loss"] pbar.set_postfix({"loss": loss.item()}) else: @@ -227,10 +226,10 @@ def main(): plugin = HybridParallelPlugin( tp_size=1, pp_size=2, - num_microbatches=2, + num_microbatches=None, pp_style="interleaved", num_model_chunks=2, - microbatch_size=None, + microbatch_size=1, enable_all_optimization=True, zero_stage=1, precision="fp16", diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 567641c4df3f..989f88c787b5 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -120,7 +120,7 @@ def examine_pp(num_micro_batches): ) # check loss - if stage_manager.is_last_stage(): + if stage_manager.is_last_device(): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients From de53d0e8a530c187db1cec449c1a39494371a6fd Mon Sep 17 00:00:00 2001 From: ppt0011 Date: Wed, 11 Oct 2023 16:58:42 +0800 Subject: [PATCH 5/7] [shardformer] move layer attr to stage manager, and style changes --- .../booster/plugin/hybrid_parallel_plugin.py | 6 ++---- colossalai/pipeline/schedule/interleaved_pp.py | 9 ++++----- colossalai/pipeline/stage_manager.py | 17 ++++++++++++++++- colossalai/shardformer/policies/bert.py | 2 +- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b5374a0b607e..9cc0d74b3556 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ 
b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -380,9 +380,7 @@ def __init__( self.stage_manager = None self.schedule = None self.custom_policy = custom_policy - self.num_microbatches = num_microbatches self.pp_style = pp_style - self.num_model_chunks = num_model_chunks assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert ( @@ -390,11 +388,11 @@ def __init__( ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager( - self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=self.num_model_chunks + self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=num_model_chunks ) if self.pp_style == "interleaved": - assert self.num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" self.schedule = InterleavedSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 2d63b4f11bb1..4e1286448589 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -35,7 +35,6 @@ def __init__( self.microbatch_offset: Optional[int] = None self._use_microbatch_size = num_microbatches is None self.num_model_chunks = num_model_chunks - self.layers = None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. 
@@ -213,7 +212,7 @@ def forward_step( # in shardformer, each device still has the entire model, so we need to pass the relevant stage layers if input_obj is None: input_obj = {} - input_obj["stage_index"] = self.layers[model_chunk_id] + input_obj["stage_index"] = self.stage_manager.layers[model_chunk_id] output_obj = model_forward(model_chunk, micro_batch, input_obj) if self.is_last_stage(model_chunk_id): @@ -335,7 +334,7 @@ def forward_backward_step( # Run warmup forward passes. for i in range(num_warmup_microbatches): model_chunk_id = self.get_model_chunk_id(i, forward=True) - self.stage_manager.model_chunk_id = model_chunk_id + self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id) # recv first on first rank to avoid sending or recving at the same time if self.stage_manager.is_first_device(): input_obj = self.recv_forward(model_chunk_id) @@ -361,7 +360,7 @@ def forward_backward_step( # Run 1F1B in steady state. for i in range(num_microbatches_remaining): model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) - self.stage_manager.model_chunk_id = model_chunk_id + self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id) last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) @@ -400,7 +399,7 @@ def forward_backward_step( if not forward_only: for i in range(num_microbatches_remaining, num_microbatches): model_chunk_id = self.get_model_chunk_id(i, forward=False) - self.stage_manager.model_chunk_id = model_chunk_id + self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id) # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") input_obj = input_objs[model_chunk_id].pop(0) output_obj = output_objs[model_chunk_id].pop(0) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 
8a623e853658..63ac0373d543 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -34,9 +34,10 @@ def __init__( # the next rank of the last rank is rank0 next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") - # number of layer chunks in each stage for interleaved pipeline, with each device has discontinuous layers + # for interleaved pipeline parallel, each device is responsible for multiple chunks of layers self.num_model_chunks = num_model_chunks self.model_chunk_id = 0 + self.layers: Optional[List[List[int]]] = None # init p2p process groups stages = list(range(self.num_stages)) @@ -153,3 +154,17 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: ProcessGroup: Process group of the given stages. """ return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) + + def set_interleaved_model_chunk_id(self, model_chunk_id: int): + """For interleaved pipeline parallel, set the model chunk id for the device at the current stage. + Args: + model_chunk_id (int): the id of the current model chunk for the device. + """ + self.model_chunk_id = model_chunk_id + + def set_interleaved_device_layers(self, layers: List[List[int]]): + """For interleaved pipeline parallel, set the layer chunks for the device. + Args: + layers (List[List[int]]): list of layer chunks for the device. 
+ """ + self.layers = layers diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 689b0b659ff1..228489ea9aca 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -275,7 +275,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli } # for interleaved, stage index for each forward is chosen in scheduler else: - self.scheduler.layers = stage_index + stage_manager.set_interleaved_device_layers(stage_index) method_replacement = { "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) } From 013fb01b8f1c3741d8cc474f857a5c2c1a215048 Mon Sep 17 00:00:00 2001 From: ppt0011 Date: Tue, 31 Oct 2023 15:01:55 +0800 Subject: [PATCH 6/7] [shardformer] increase micro batch size due to convergence issue --- examples/language/bert/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 720a4e6699d9..8ab3eeb3110c 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -229,7 +229,7 @@ def main(): num_microbatches=None, pp_style="interleaved", num_model_chunks=2, - microbatch_size=1, + microbatch_size=16, enable_all_optimization=True, zero_stage=1, precision="fp16", From 53fe53cc2d535616d55550e93262f833d93680b4 Mon Sep 17 00:00:00 2001 From: ppt0011 Date: Wed, 1 Nov 2023 11:43:47 +0800 Subject: [PATCH 7/7] increase ci timeout time after discussion --- .github/workflows/example_check_on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 5934704f4102..859b6e4fb556 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -79,7 +79,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v 
/data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 20 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} cancel-in-progress: true