diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 5934704f4102..859b6e4fb556 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -79,7 +79,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 20
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
       cancel-in-progress: true
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 479ccc3eb36e..9cc0d74b3556 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -20,7 +20,7 @@
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
@@ -317,6 +317,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+        pp_style (str, optional): The style for pipeline parallelism, either '1f1b' or 'interleaved'. Defaults to '1f1b'.
+        num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
""" def __init__( @@ -352,6 +354,8 @@ def __init__( communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, ) -> None: super().__init__() assert ( @@ -376,22 +380,37 @@ def __init__( self.stage_manager = None self.schedule = None self.custom_policy = custom_policy + self.pp_style = pp_style assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule( - self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + self.stage_manager = PipelineStageManager( + self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=num_model_chunks ) + + if self.pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + num_model_chunks=num_model_chunks, + ) + else: + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, + scheduler=self.schedule, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 780437155c61..4e1286448589 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -3,7 +3,7 @@ import torch import torch.cuda -from torch.nn import Module +from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper @@ -16,18 +16,25 @@ class InterleavedSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: - self.num_model_chunks = num_model_chunks - assert ( - num_microbatches % self.num_model_chunks == 0 - ), "Number of microbatches should be an integer multiple of number of model chunks" + def __init__( + self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + num_model_chunks: Optional[int] = 1, + ) -> None: super().__init__(stage_manager) + assert ( + num_microbatches is not None or microbatch_size is not None + ), "Either num_microbatches or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches + self.microbatch_size = microbatch_size self.batch: Optional[Any] = None self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None - self.microbatch_size: Optional[int] = None + self._use_microbatch_size = num_microbatches is None + 
+        self.num_model_chunks = num_model_chunks

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -42,8 +49,22 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-        assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        if not self._use_microbatch_size:
+            assert (
+                self.batch_size % self.num_microbatches == 0
+            ), "Batch size should be divisible by the number of microbatches"
+            self.microbatch_size = self.batch_size // self.num_microbatches
+        else:
+            assert self.batch_size % self.microbatch_size == 0, "Batch size should be divisible by the microbatch size"
+            self.num_microbatches = self.batch_size // self.microbatch_size
+
+        assert (
+            self.num_microbatches % self.num_model_chunks == 0
+        ), "Number of microbatches should be an integer multiple of the number of model chunks"
+
+        assert (
+            self.num_microbatches % self.stage_manager.num_stages == 0
+        ), "Number of microbatches should be an integer multiple of the number of pipeline parallel devices"

     def load_micro_batch(self, model_chunk_id: int) -> Any:
         """Load a micro batch from the current batch.
@@ -83,7 +104,7 @@ def is_first_stage(self, model_chunk_id: int) -> bool:
         Returns:
             bool: Whether the current virtual stage is the first stage.
         """
-        if self.stage_manager.is_first_stage() and model_chunk_id == 0:
+        if self.stage_manager.is_first_device() and model_chunk_id == 0:
             return True
         return False
@@ -96,7 +117,7 @@ def is_last_stage(self, model_chunk_id: int) -> bool:
         Returns:
             bool: Whether the current virtual stage is the last stage.
         """
-        if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
+        if self.stage_manager.is_last_device() and model_chunk_id == self.num_model_chunks - 1:
             return True
         return False
@@ -162,7 +183,7 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None)

     def forward_step(
         self,
-        model_chunk: Module,
+        model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         input_obj: Optional[dict],
         criterion: Callable,
@@ -171,7 +192,7 @@
     ) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline
         Args:
-            model (Module): Model Chunk to be run
+            model_chunk (ModuleList or Module): Model chunk to be run
             input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
             criterion (Callable): Criterion to calculate loss.
             accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
@@ -184,7 +205,15 @@ def forward_step(
         # for the first stage, input_obj is None
         # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
-        output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+
+        if isinstance(model_chunk, ModuleList):
+            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+        else:
+            # in shardformer, each device still has the entire model, so we need to pass the relevant stage layers
+            if input_obj is None:
+                input_obj = {}
+            input_obj["stage_index"] = self.stage_manager.layers[model_chunk_id]
+            output_obj = model_forward(model_chunk, micro_batch, input_obj)

         if self.is_last_stage(model_chunk_id):
             loss = criterion(output_obj, micro_batch) / self.num_microbatches
@@ -243,7 +272,7 @@ def backward_step(

     def forward_backward_step(
         self,
-        model_chunk: Module,
+        model_chunk: Union[ModuleList, Module],
         data_iter: Iterable,
         criterion: Callable[..., Any],
         optimizer: Optional[OptimizerWrapper] = None,
@@ -253,7 +282,7 @@
         """Runs interleaved 1F1B schedule, with communication between pipeline stages.
         Args:
-            model_chunk (List[Module]): Model Chunk to be trained.
+            model_chunk (ModuleList or Module): Model chunk to be trained. The original interleaved schedule takes a ModuleList, whereas Shardformer passes the entire model together with a layer specification.
             data_iter (Iterable): Data iterator.
             criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
             optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
@@ -268,7 +297,7 @@
             assert forward_only, "Optimizer should be passed when doing backward."

         self.load_batch(data_iter)
-        num_model_chunks = len(model_chunk)
+        num_model_chunks = self.num_model_chunks

         # num_warmup_microbatches is the number of steps during which not all devices are working
         num_microbatches = self.num_microbatches * num_model_chunks
@@ -289,9 +318,9 @@
         input_objs = [[] for _ in range(num_model_chunks)]
         output_objs = [[] for _ in range(num_model_chunks)]

-        outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+        outputs = [] if return_outputs and self.stage_manager.is_last_device() else None

-        if return_loss and self.stage_manager.is_last_stage():
+        if return_loss and self.stage_manager.is_last_device():
             accum_loss = torch.zeros(1, device=get_current_device())
         else:
             accum_loss = None
@@ -299,13 +328,15 @@
         # for ranks except the first one, get into recv state
         # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
         input_obj = self.recv_forward(0)
-        input_objs[0].append(input_obj)
+        if not forward_only:
+            input_objs[0].append(input_obj)
+
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             model_chunk_id = self.get_model_chunk_id(i, forward=True)
-
+            self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
             # recv first on first rank to avoid sending or receiving at the same time
-            if self.stage_manager.is_first_stage():
+            if self.stage_manager.is_first_device():
                 input_obj = self.recv_forward(model_chunk_id)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             self.send_forward(model_chunk_id, output_obj)
@@ -329,6 +360,7 @@
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
             model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
+            self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
             last_iteration = i == (num_microbatches_remaining - 1)

             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
@@ -367,10 +399,10 @@
         if not forward_only:
             for i in range(num_microbatches_remaining, num_microbatches):
                 model_chunk_id = self.get_model_chunk_id(i, forward=False)
+                self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
                 # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
                 input_obj = input_objs[model_chunk_id].pop(0)
                 output_obj = output_objs[model_chunk_id].pop(0)
-
                 output_obj_grad = self.recv_backward(model_chunk_id)
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                 self.send_backward(model_chunk_id, input_obj_grad)
diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py
index b79867a2c651..63ac0373d543 100644
--- a/colossalai/pipeline/stage_manager.py
+++ b/colossalai/pipeline/stage_manager.py
@@ -18,7 +18,9 @@ class PipelineStageManager:
         stage (int): The current stage.
     """

-    def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
+    def __init__(
+        self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False, num_model_chunks=1
+    ) -> None:
         self.pg_mesh = pg_mesh
         self.pipeline_axis = pipeline_axis
         self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -32,6 +34,10 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo
             # the next rank of the last rank is rank0
             next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
             self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
+        # for interleaved pipeline parallelism, each device is responsible for multiple chunks of layers
+        self.num_model_chunks = num_model_chunks
+        self.model_chunk_id = 0
+        self.layers: Optional[List[List[int]]] = None

         # init p2p process groups
         stages = list(range(self.num_stages))
@@ -55,7 +61,7 @@ def is_first_stage(self) -> bool:
         Returns:
             bool: Whether the current stage is the first stage.
         """
-        return self.stage == 0
+        return self.stage == 0 and self.model_chunk_id == 0

     def is_last_stage(self) -> bool:
         """Is the current stage the last stage.
@@ -63,6 +69,23 @@ def is_last_stage(self) -> bool:
         Returns:
             bool: Whether the current stage is the last stage.
         """
+        return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
+
+    # introduced for interleaved pipeline parallelism, as the first/last device may also hold intermediate stages
+    def is_first_device(self) -> bool:
+        """Is the current stage on the first device.
+
+        Returns:
+            bool: Whether the current stage is on the first device.
+        """
+        return self.stage == 0
+
+    def is_last_device(self) -> bool:
+        """Is the current stage on the last device.
+
+        Returns:
+            bool: Whether the current stage is on the last device.
+        """
         return self.stage == self.num_stages - 1

     @property
@@ -131,3 +154,17 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
             ProcessGroup: Process group of the given stages.
""" return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) + + def set_interleaved_model_chunk_id(self, model_chunk_id: int): + """For interleaved pipeline parallel, set the model chunk id for the device at the current stage. + Args: + model_chunk_id (int): the id of the current model chunk for the device. + """ + self.model_chunk_id = model_chunk_id + + def set_interleaved_device_layers(self, layers: List[List[int]]): + """For interleaved pipeline parallel, set the layer chunks for the device. + Args: + layers (List[List[int]]): list of layer chunks for the device. + """ + self.layers = layers diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e7f199129a00..789565e5a010 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -9,6 +9,7 @@ from torch import Tensor from torch.nn import Module +from colossalai.pipeline.schedule import PipelineSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from ..layer.parallel_module import ParallelModule @@ -99,6 +100,12 @@ def pipeline_stage_manager(self) -> Optional[PipelineStageManager]: return self.shard_config.pipeline_stage_manager return None + @property + def scheduler(self) -> Optional[PipelineSchedule]: + if self.shard_config is not None: + return self.shard_config.scheduler + return None + @abstractmethod def config_sanity_check(self): """ @@ -214,13 +221,30 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: return layers_per_stage @staticmethod - def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: - """ - get the start index and end index of layers for each stage. - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - start_idx = num_layers_per_stage_accumulated[stage] - end_idx = num_layers_per_stage_accumulated[stage + 1] - - return [start_idx, end_idx] + def get_stage_index( + layers_per_stage: List[int], stage: int, num_stages=None, num_model_chunks=1 + ) -> Union[List[int], List[List[int]]]: + # num_stages info is only needed for interleaved pipeline stage assignment + if num_stages is None: + """ + get the start index and end index of layers for each stage. + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + start_idx = num_layers_per_stage_accumulated[stage] + end_idx = num_layers_per_stage_accumulated[stage + 1] + + return [start_idx, end_idx] + else: + """ + interleaved pipeline: get the start index and end index PAIRS for each stage. 
+            num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
+
+            stage_indexes = []
+            for model_chunk in range(num_model_chunks):
+                start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
+                end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
+                stage_indexes.append([start_idx, end_idx])
+
+            return stage_indexes
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 14146de158ae..228489ea9aca 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -251,13 +251,34 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             module = self.model.bert

-        layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {
-            "forward": partial(
-                new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+        # num_model_chunks > 1 if interleaved
+        num_model_chunks = stage_manager.num_model_chunks
+        layers_per_stage = Policy.distribute_layers(
+            len(module.encoder.layer), stage_manager.num_stages * num_model_chunks
+        )
+
+        if num_model_chunks > 1:
+            stage_index = Policy.get_stage_index(
+                layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks
             )
-        }
+        else:
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+
+        if num_model_chunks == 1:
+            method_replacement = {
+                "forward": partial(
+                    new_forward,
+                    stage_manager=stage_manager,
+                    stage_index=stage_index,
+                    shard_config=self.shard_config,
+                )
+            }
+        # for interleaved, the stage index for each forward pass is chosen by the scheduler
+        else:
+            stage_manager.set_interleaved_device_layers(stage_index)
+            method_replacement = {
+                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+            }
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
         )
@@ -275,12 +296,29 @@ def get_held_layers(self) -> List[Module]:
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-        if stage_manager.is_first_stage():
+        num_model_chunks = stage_manager.num_model_chunks
+        layers_per_stage = self.distribute_layers(
+            len(module.encoder.layer), stage_manager.num_stages * num_model_chunks
+        )
+        if stage_manager.is_first_device():
             held_layers.append(module.embeddings)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
-        held_layers.extend(module.encoder.layer[start_idx:end_idx])
-        if stage_manager.is_last_stage():
+        if num_model_chunks > 1:
+            stage_index = Policy.get_stage_index(
+                layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks
+            )
+        else:
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+
+        # interleaved stage indexes for one device come in pairs, e.g. [[0, 3], [6, 9]]
+        if all(isinstance(item, list) for item in stage_index):
+            for i in range(len(stage_index)):
+                start_idx, end_idx = stage_index[i]
+                held_layers.extend(module.encoder.layer[start_idx:end_idx])
+        else:
+            start_idx, end_idx = stage_index
+            held_layers.extend(module.encoder.layer[start_idx:end_idx])
+
+        if stage_manager.is_last_device():
             held_layers.append(module.pooler)

         return held_layers
@@ -472,7 +510,7 @@ def get_held_layers(self) -> List[Module]:
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_stage():
+        if stage_manager.is_last_device():
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index a285874d218b..ca9a17784013 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -4,6 +4,7 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

+from colossalai.pipeline.schedule import PipelineSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager

 __all__ = ["ShardConfig"]
@@ -17,6 +18,7 @@ class ShardConfig:
     Args:
         tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group.
         pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
+        scheduler (Optional[PipelineSchedule]): If using interleaved pipeline parallelism, it's necessary to specify the scheduler that assigns layers to each device. Defaults to None.
         enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True.
         enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False.
         enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
@@ -28,6 +30,7 @@
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
+    scheduler: Optional[PipelineSchedule] = None
     enable_tensor_parallelism: bool = True
     enable_fused_normalization: bool = False
     enable_flash_attention: bool = False
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index 563cfa58d5f6..8ab3eeb3110c 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -57,21 +57,27 @@ def evaluate_model(
     def evaluate_subset(dataloader: DataLoader):
         use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-        is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device()

         accum_loss = torch.zeros(1, device=get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
             labels = batch["labels"]
             if use_pipeline:
+                # skip the last batch (batch size 31) for interleaved pipeline parallelism,
+                # as the number of microbatches must be a multiple of the number of pipeline parallel devices
+                if booster.plugin.pp_style == "interleaved" and len(labels) < 32:
+                    continue
                 pg_mesh = booster.plugin.pg_mesh
                 pp_group = booster.plugin.pp_group
                 current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
                 current_rank = dist.get_rank()
                 batch = iter([batch])
+
                 outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
-                if is_pp_last_stage:
+                if is_pp_last_device:
                     logits = outputs["outputs"]["logits"]
                     val_loss = outputs["loss"]
                     accum_loss.add_(val_loss)
@@ -87,7 +93,6 @@ def evaluate_subset(dataloader: DataLoader):
                 elif current_rank in current_pp_group_ranks:
                     object_list = [None, None]
                     dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)
-
                     metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
                     accum_loss.add_(object_list[1].to(get_current_device()))
@@ -133,8 +138,8 @@ def train_epoch(
     coordinator: DistCoordinator,
 ):
     use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
-    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device()
+    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
     total_step = len(train_dataloader)

     model.train()
@@ -148,7 +153,7 @@ def train_epoch(
             train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
         )
         # Backward and optimize
-        if is_pp_last_stage:
+        if is_pp_last_device:
             loss = outputs["loss"]
             pbar.set_postfix({"loss": loss.item()})
         else:
@@ -222,7 +227,9 @@ def main():
         tp_size=1,
         pp_size=2,
         num_microbatches=None,
-        microbatch_size=1,
+        pp_style="interleaved",
+        num_model_chunks=2,
+        microbatch_size=16,
         enable_all_optimization=True,
         zero_stage=1,
         precision="fp16",
diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py
index f181453eaed5..989f88c787b5 100644
--- a/tests/test_pipeline/test_schedule/test_interleaved.py
+++ b/tests/test_pipeline/test_schedule/test_interleaved.py
@@ -66,7 +66,7 @@ def examine_pp(num_micro_batches):
     seed_all(1453)

     NUM_MICRO_BATCHS = num_micro_batches
-    BATCH_SIZE = num_micro_batches
+    BATCH_SIZE = 24
     NUM_CHUNKS = 2

     # create model
@@ -76,8 +76,13 @@ def examine_pp(num_micro_batches):
     DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
     pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
-    schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
+    stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True, num_model_chunks=NUM_CHUNKS)
+    schedule = InterleavedSchedule(
+        stage_manager=stage_manager,
+        num_microbatches=NUM_MICRO_BATCHS,
+        microbatch_size=None,
+        num_model_chunks=NUM_CHUNKS,
+    )

     sharded_model = torch.nn.ModuleList()
     for idx, (_, sub_model) in enumerate(pp_model.named_children()):
@@ -115,7 +120,7 @@ def examine_pp(num_micro_batches):
     )

     # check loss
-    if stage_manager.is_last_stage():
+    if stage_manager.is_last_device():
         assert torch.allclose(torch_loss, pp_ret["loss"])

     # check gradients
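
The reworked InterleavedSchedule above accepts either num_microbatches or microbatch_size, derives the other from the batch size in load_batch, and then checks divisibility against the number of model chunks and pipeline devices. A minimal standalone sketch of that bookkeeping follows, using plain ints in place of the real stage manager; the helper name resolve_microbatches is hypothetical and not part of the library:

def resolve_microbatches(batch_size: int, num_model_chunks: int, num_stages: int,
                         num_microbatches=None, microbatch_size=None):
    # mirror of InterleavedSchedule.load_batch: at least one of the two must be given
    assert num_microbatches is not None or microbatch_size is not None
    if num_microbatches is not None:
        assert batch_size % num_microbatches == 0, "batch size must be divisible by the number of microbatches"
        microbatch_size = batch_size // num_microbatches
    else:
        assert batch_size % microbatch_size == 0, "batch size must be divisible by the microbatch size"
        num_microbatches = batch_size // microbatch_size
    # interleaving runs whole rounds over all model chunks and all devices
    assert num_microbatches % num_model_chunks == 0
    assert num_microbatches % num_stages == 0
    return num_microbatches, microbatch_size

# e.g. a batch of 32 with microbatch_size=16, 2 chunks, 2 devices -> (2, 16)
print(resolve_microbatches(32, 2, 2, microbatch_size=16))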
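In get_stage_index, chunk c on pipeline stage s covers virtual stage s + c * num_stages, so each device receives one [start, end) pair per model chunk. A small worked example of that indexing, under the assumption of 12 layers split over 2 devices with 2 chunks each:

import numpy as np

# 12 layers over 2 devices * 2 chunks = 4 virtual stages of 3 layers each
layers_per_stage = [3, 3, 3, 3]
acc = np.insert(np.cumsum(layers_per_stage), 0, 0)  # [0, 3, 6, 9, 12]

num_stages, num_model_chunks = 2, 2
for stage in range(num_stages):
    pairs = [[int(acc[stage + c * num_stages]), int(acc[stage + c * num_stages + 1])]
             for c in range(num_model_chunks)]
    print(f"stage {stage}: {pairs}")
# stage 0: [[0, 3], [6, 9]]; stage 1: [[3, 6], [9, 12]]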
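Finally, a hedged sketch of how the new plugin options fit together end to end, mirroring the finetune.py change above; model, optimizer, and dataloader setup are elided, and the argument values are illustrative:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# pp_size devices, each holding num_model_chunks chunks of layers, so the model
# is cut into pp_size * num_model_chunks virtual stages; with microbatch_size=16,
# a batch of 32 yields 2 microbatches, a multiple of both pp_size and num_model_chunks
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=None,
    pp_style="interleaved",
    num_model_chunks=2,
    microbatch_size=16,
    enable_all_optimization=True,
    zero_stage=1,
    precision="fp16",
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, and lr_scheduler would then be wrapped
# via booster.boost(...) before calling booster.execute_pipeline(...) per step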