diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 9cc0d74b3556..e999c27d672f 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -380,29 +380,38 @@ def __init__(
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
-        self.pp_style = pp_style
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
+            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
             self.stage_manager = PipelineStageManager(
-                self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=num_model_chunks
+                self.pg_mesh,
+                pipeline_axis=PP_AXIS,
+                enable_interleave=pp_style == "interleaved",
+                num_model_chunks=num_model_chunks,
             )
-            if self.pp_style == "interleaved":
+            if pp_style == "interleaved":
                 assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
                 self.schedule = InterleavedSchedule(
                     stage_manager=self.stage_manager,
-                    num_microbatches=num_microbatches,
-                    microbatch_size=microbatch_size,
                     num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
                 )
-            else:
+            elif pp_style == "1f1b":
                 self.schedule = OneForwardOneBackwardSchedule(
-                    self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+                    self.stage_manager,
+                    num_microbatches=num_microbatches,
+                    microbatch_size=microbatch_size,
                 )
+            else:
+                raise NotImplementedError()
+
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
@@ -410,7 +419,6 @@ def __init__(
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             pipeline_stage_manager=self.stage_manager,
-            scheduler=self.schedule,
             enable_tensor_parallelism=self.tp_size > 1,
             enable_all_optimization=self.enable_all_optimization,
             enable_fused_normalization=self.enable_fused_normalization,
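Review note: the constructor now validates `pp_style` up front and dispatches on it explicitly, which reads better than the old implicit `else` branch. A minimal usage sketch under assumed settings (the 4-stage launch context and hyperparameters are hypothetical; argument names follow this diff):

```python
# Hypothetical sketch: selecting the interleaved schedule through the plugin.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,                # 4 pipeline stages
    pp_style="interleaved",   # must be "1f1b" or "interleaved" after this change
    num_model_chunks=2,       # must be > 1 for the interleaved schedule
    num_microbatches=8,       # or pass microbatch_size instead
    zero_stage=1,             # pipeline parallelism requires zero stage <= 1
)
booster = Booster(plugin=plugin)
```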
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 4e1286448589..5471419958e1 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -19,23 +19,23 @@ class InterleavedSchedule(PipelineSchedule):
     def __init__(
         self,
         stage_manager: PipelineStageManager,
-        num_microbatches: Optional[int] = None,
+        num_model_chunks: int,
+        num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
-        num_model_chunks: Optional[int] = 1,
     ) -> None:
         super().__init__(stage_manager)
         assert (
-            num_microbatches is not None or microbatch_size is not None
-        ), "Either num_microbatches or microbatch_size should be provided"
+            num_microbatch is not None or microbatch_size is not None
+        ), "Either num_microbatch or microbatch_size should be provided"
         self.comm = PipelineP2PCommunication(stage_manager)
-        self.num_microbatches = num_microbatches
+        self.num_microbatch = num_microbatch
         self.microbatch_size = microbatch_size
-        self.batch: Optional[Any] = None
-        self.batch_size: Optional[int] = None
-        self.microbatch_offset: Optional[int] = None
-        self._use_microbatch_size = num_microbatches is None
         self.num_model_chunks = num_model_chunks
+        self.batch: Any
+        self.batch_size: int
+        self.microbatch_offset: List[int]
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -49,22 +49,25 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-        if not self._use_microbatch_size:
+        if self.num_microbatch is not None:
             assert (
-                self.batch_size % self.num_microbatches == 0
-            ), "Batch size should divided by the number of microbatches"
-            self.microbatch_size = self.batch_size // self.num_microbatches
-        else:
+                self.batch_size % self.num_microbatch == 0
+            ), "Batch size should be divisible by num_microbatch"
+            self.microbatch_size = self.batch_size // self.num_microbatch
+        elif self.microbatch_size is not None:
             assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
-            self.num_microbatches = self.batch_size // self.microbatch_size
+            self.num_microbatch = self.batch_size // self.microbatch_size
+        else:
+            raise ValueError("Either num_microbatch or microbatch_size should be provided")
         assert (
-            self.num_microbatches % self.num_model_chunks == 0
-        ), "Number of microbatches should be an integer multiple of number of model chunks"
+            self.num_microbatch % self.num_model_chunks == 0
+        ), "num_microbatch should be an integer multiple of the number of model chunks"
         assert (
-            self.num_microbatches % self.stage_manager.num_stages == 0
-        ), "Number of microbatches should be an integer multiple of number of pipeline parallel devices"
+            self.num_microbatch % self.stage_manager.num_stages == 0
+        ), "num_microbatch should be an integer multiple of the number of pipeline parallel devices"
 
     def load_micro_batch(self, model_chunk_id: int) -> Any:
         """Load a micro batch from the current batch.
@@ -79,7 +82,7 @@ def load_micro_batch(self, model_chunk_id: int) -> Any:
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)
 
-    def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
+    def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
         """Helper method to get the model chunk ID given the iteration number.
 
         Args:
@@ -91,36 +94,10 @@ def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
         """
         microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
         model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
-        if not forward:
+        if not is_forward:
             model_chunk_id = self.num_model_chunks - model_chunk_id - 1
         return model_chunk_id
 
-    def is_first_stage(self, model_chunk_id: int) -> bool:
-        """Is the current virtual stage the first stage
-
-        Args:
-            model_chunk_id (int): The current model chunk idx.
-
-        Returns:
-            bool: Whether the current virtual stage is the first stage.
-        """
-        if self.stage_manager.is_first_device() and model_chunk_id == 0:
-            return True
-        return False
-
-    def is_last_stage(self, model_chunk_id: int) -> bool:
-        """Is the current virtual stage the last stage
-
-        Args:
-            model_chunk_id (int): The current model chunk idx.
-
-        Returns:
-            bool: Whether the current virtual stage is the last stage.
-        """
-        if self.stage_manager.is_last_device() and model_chunk_id == self.num_model_chunks - 1:
-            return True
-        return False
-
     def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
            For interleaved 1F1B.
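Review note: the renamed `get_model_chunk_id` is the heart of the interleaving, and the mapping is easiest to verify with concrete numbers. A standalone re-derivation (pure Python; the sizes are example values, not part of this patch):

```python
# Re-derivation of get_model_chunk_id for intuition; mirrors the hunk above.
num_stages = 4        # pipeline devices
num_model_chunks = 2  # virtual stages per device

def get_model_chunk_id(microbatch_id: int, is_forward: bool) -> int:
    microbatch_id_in_group = microbatch_id % (num_stages * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // num_stages
    if not is_forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

# Forward runs chunk 0 for num_stages microbatches, then chunk 1, then repeats:
assert [get_model_chunk_id(i, True) for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
# Backward visits the chunks in reverse order:
assert [get_model_chunk_id(i, False) for i in range(8)] == [1, 1, 1, 1, 0, 0, 0, 0]
```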
- """ - if self.stage_manager.is_last_device() and model_chunk_id == self.num_model_chunks - 1: - return True - return False - def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For interleaved 1F1B. @@ -132,7 +109,7 @@ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. """ - if self.is_first_stage(model_chunk_id): + if self.stage_manager.is_first_stage(model_chunk_id): input_tensor = None else: input_tensor = self.comm.recv_forward(prev_rank) @@ -150,7 +127,7 @@ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.is_last_stage(model_chunk_id): + if self.stage_manager.is_last_stage(model_chunk_id): output_tensor_grad = None else: output_tensor_grad = self.comm.recv_backward(next_rank) @@ -166,7 +143,7 @@ def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.is_last_stage(model_chunk_id): + if not self.stage_manager.is_last_stage(model_chunk_id): self.comm.send_forward(output_object, next_rank) def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: @@ -178,7 +155,7 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.is_first_stage(model_chunk_id): + if not self.stage_manager.is_first_stage(model_chunk_id): self.comm.send_backward(input_object, prev_rank) def forward_step( @@ -206,17 +183,18 @@ def forward_step( # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + self.stage_manager.model_chunk_id = model_chunk_id if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) else: - # in shardformer, each device still has the entire model, so we need to pass the relevant stage layers - if input_obj is None: - input_obj = {} - input_obj["stage_index"] = self.stage_manager.layers[model_chunk_id] - output_obj = model_forward(model_chunk, micro_batch, input_obj) - - if self.is_last_stage(model_chunk_id): - loss = criterion(output_obj, micro_batch) / self.num_microbatches + # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + self.stage_manager.model_chunk_id = None + + if self.stage_manager.is_last_stage(model_chunk_id): + loss = criterion(output_obj, micro_batch) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: @@ -279,7 +257,7 @@ def forward_backward_step( return_loss: bool = False, return_outputs: bool = False, ) -> dict: - """Runs interleaved 1F1B schedule, with communication between pipeline stages. + """Runs interleaved schedule, with communication between pipeline stages. Args: model_chunk (ModuleList or Module): Model Chunk to be trained. 
@@ -279,7 +257,7 @@ def forward_backward_step(
         return_loss: bool = False,
         return_outputs: bool = False,
     ) -> dict:
-        """Runs interleaved 1F1B schedule, with communication between pipeline stages.
+        """Runs interleaved schedule, with communication between pipeline stages.
 
         Args:
             model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
@@ -292,51 +270,46 @@ def forward_backward_step(
         Returns:
             dict: A dict with keys: 'loss' and 'outputs'.
         """
+        # TODO: handle arbitrary batch size when forward_only == True
         forward_only = not torch.is_grad_enabled()
         if optimizer is None:
             assert forward_only, "Optimizer should be passed when doing backward."
 
         self.load_batch(data_iter)
 
-        num_model_chunks = self.num_model_chunks
-        # num_warmup_microbatches is the step when not all the processes are working
-        num_microbatches = self.num_microbatches * num_model_chunks
+        num_microbatch = self.num_microbatch * self.num_model_chunks
         if forward_only:
-            num_warmup_microbatches = num_microbatches
+            num_warmup_microbatch = num_microbatch
         else:
-            num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
-            num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
-            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+            num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+            num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
+            num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
 
-        num_microbatches_remaining = num_microbatches - num_warmup_microbatches
+        num_microbatch_remaining = num_microbatch - num_warmup_microbatch
 
         # Input, output tensors only need to be saved when doing backward passes
         input_objs = None
         output_objs = None
         if not forward_only:
-            input_objs = [[] for _ in range(num_model_chunks)]
-            output_objs = [[] for _ in range(num_model_chunks)]
+            input_objs = [[] for _ in range(self.num_model_chunks)]
+            output_objs = [[] for _ in range(self.num_model_chunks)]
 
-        outputs = [] if return_outputs and self.stage_manager.is_last_device() else None
+        outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None
 
-        if return_loss and self.stage_manager.is_last_device():
+        if return_loss and self.stage_manager.is_last_stage(-1):
             accum_loss = torch.zeros(1, device=get_current_device())
         else:
             accum_loss = None
 
         # for ranks except the first one, get into recv state
-        # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
         input_obj = self.recv_forward(0)
-        if not forward_only:
-            input_objs[0].append(input_obj)
 
         # Run warmup forward passes.
-        for i in range(num_warmup_microbatches):
-            model_chunk_id = self.get_model_chunk_id(i, forward=True)
-            self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
-            # recv first on first rank to avoid sending or recving at the same time
-            if self.stage_manager.is_first_device():
+        for i in range(num_warmup_microbatch):
+            model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
+            # recv first on first rank to avoid sending or receiving at the same time
+            if self.stage_manager.is_first_stage(-1):
                 input_obj = self.recv_forward(model_chunk_id)
                 output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
                 self.send_forward(model_chunk_id, output_obj)
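Review note: the warmup count per stage is worth checking with numbers. A quick standalone evaluation of the formula above (4 stages, 2 chunks, and 8 microbatches per chunk are assumed values):

```python
# Warmup-microbatch count per stage, mirroring the formula in this hunk.
num_stages, num_model_chunks, microbatch_per_chunk = 4, 2, 8
num_microbatch = microbatch_per_chunk * num_model_chunks  # 16 slots in total

for stage in range(num_stages):
    warmup = (num_stages - stage - 1) * 2 + (num_model_chunks - 1) * num_stages
    warmup = min(warmup, num_microbatch)
    print(f"stage {stage}: {warmup} warmup, {num_microbatch - warmup} steady 1F1B steps")
# stage 0 warms up for 10 microbatches, stage 3 for only 4: earlier stages idle longer.
```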
@@ -346,22 +319,21 @@ def forward_backward_step(
             else:
                 output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
                 if not forward_only:
+                    input_objs[model_chunk_id].append(input_obj)
                     output_objs[model_chunk_id].append(output_obj)
                 self.send_forward(model_chunk_id, output_obj)
 
-            if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
-                break
-            else:
-                model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
-                input_obj = self.recv_forward(model_chunk_id)
-                if not forward_only:
-                    input_objs[model_chunk_id].append(input_obj)
+            if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch:
+                break
+
+            model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True)
+            input_obj = self.recv_forward(model_chunk_id)
 
         # Run 1F1B in steady state.
-        for i in range(num_microbatches_remaining):
-            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
-            self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
-            last_iteration = i == (num_microbatches_remaining - 1)
+        for i in range(num_microbatch_remaining):
+            model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True)
+            last_iteration = i == num_microbatch_remaining - 1
 
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             if forward_only:
@@ -376,7 +348,7 @@ def forward_backward_step(
                 input_objs[model_chunk_id].append(input_obj)
                 output_objs[model_chunk_id].append(output_obj)
 
-                model_chunk_id = self.get_model_chunk_id(i, forward=False)
+                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
                 output_obj_grad = self.recv_backward(model_chunk_id)
 
                 # Pop output_obj and output_obj from the start of the list for
@@ -390,23 +362,25 @@ def forward_backward_step(
                 if last_iteration:
                     input_obj = None
                 else:
-                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
+                    model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
                     input_obj = self.recv_forward(model_chunk_id)
-                model_chunk_id = self.get_model_chunk_id(i, forward=False)
+
+                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
                 self.send_backward(model_chunk_id, input_obj_grad)
 
         # Run cooldown backward passes.
         if not forward_only:
-            for i in range(num_microbatches_remaining, num_microbatches):
-                model_chunk_id = self.get_model_chunk_id(i, forward=False)
-                self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id)
-                # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
+            for i in range(num_microbatch_remaining, num_microbatch):
+                model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
                 input_obj = input_objs[model_chunk_id].pop(0)
                 output_obj = output_objs[model_chunk_id].pop(0)
                 output_obj_grad = self.recv_backward(model_chunk_id)
                 input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
                 self.send_backward(model_chunk_id, input_obj_grad)
 
+        if not forward_only:
+            assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
+
         if outputs is not None:
             outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
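Review note: the new drain assertion is a cheap guard that every stashed activation was consumed exactly once; a scheduling bug would otherwise silently leak memory. The invariant in miniature (toy objects, no ColossalAI involved):

```python
# The invariant behind the new assert: enqueue on forward, dequeue on backward.
input_objs = [[], []]                  # one stash per model chunk
for chunk in (0, 0, 1, 1):             # forward passes enqueue...
    input_objs[chunk].append(object())
for chunk in (1, 1, 0, 0):             # ...backward passes dequeue, chunks reversed
    input_objs[chunk].pop(0)
assert all(len(v) == 0 for v in input_objs)
```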
- """ - return self.stage == 0 and self.model_chunk_id == 0 + # for shardformer, hold stage indices of model + self.stage_indices: List[Tuple[int, int]] + # for shardformer, hold model chunk id + self.model_chunk_id: Optional[int] = None - def is_last_stage(self) -> bool: - """Is the current stage the last stage. - - Returns: - bool: Whether the current stage is the last stage. - """ - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: + """Is the current stage the first stage. - # introduced due to interleaved pipeline parallel, as the first/last device may also hold intermediate stages - def is_first_device(self) -> bool: - """Is the current stage on the first device. + NOTE: + 1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device. + 2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device() Returns: - bool: Whether the current stage is on the first device. + bool: Whether the current stage is the first stage. """ - return self.stage == 0 + if self.is_interleave and model_chunk_id is None: + model_chunk_id = self.model_chunk_id + assert self.is_interleave ^ (model_chunk_id is None), \ + "model_chunk_id must be specified when using interleaved pipeline" + if not self.is_interleave or model_chunk_id < 0: + return self.stage == 0 + else: + return self.stage == 0 and model_chunk_id == 0 + + def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: + """Is the current stage the last stage. - def is_last_device(self) -> bool: - """Is the current stage on the last device. + NOTE: + 1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device. + 2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device() Returns: - bool: Whether the current stage on the last device. + bool: Whether the current stage is the last stage. """ - return self.stage == self.num_stages - 1 + if self.is_interleave and model_chunk_id is None: + model_chunk_id = self.model_chunk_id + assert self.is_interleave ^ (model_chunk_id is None), \ + "model_chunk_id must be specified when using interleaved pipeline" + if not self.is_interleave or model_chunk_id < 0: + return self.stage == self.num_stages - 1 + else: + return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: @@ -154,17 +171,3 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: ProcessGroup: Process group of the given stages. """ return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) - - def set_interleaved_model_chunk_id(self, model_chunk_id: int): - """For interleaved pipeline parallel, set the model chunk id for the device at the current stage. - Args: - model_chunk_id (int): the id of the current model chunk for the device. - """ - self.model_chunk_id = model_chunk_id - - def set_interleaved_device_layers(self, layers: List[List[int]]): - """For interleaved pipeline parallel, set the layer chunks for the device. - Args: - layers (List[List[int]]): list of layer chunks for the device. 
- """ - self.layers = layers diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 789565e5a010..725441f25c27 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -2,14 +2,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch.nn as nn from torch import Tensor from torch.nn import Module -from colossalai.pipeline.schedule import PipelineSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from ..layer.parallel_module import ParallelModule @@ -100,12 +99,6 @@ def pipeline_stage_manager(self) -> Optional[PipelineStageManager]: return self.shard_config.pipeline_stage_manager return None - @property - def scheduler(self) -> Optional[PipelineSchedule]: - if self.shard_config is not None: - return self.shard_config.scheduler - return None - @abstractmethod def config_sanity_check(self): """ @@ -222,29 +215,31 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: @staticmethod def get_stage_index( - layers_per_stage: List[int], stage: int, num_stages=None, num_model_chunks=1 - ) -> Union[List[int], List[List[int]]]: - # num_stages info is only needed for interleaved pipeline stage assignment - if num_stages is None: - """ - get the start index and end index of layers for each stage. - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - start_idx = num_layers_per_stage_accumulated[stage] - end_idx = num_layers_per_stage_accumulated[stage + 1] - - return [start_idx, end_idx] - else: - """ - interleaved pipeline: get the start index and end index PAIRS for each stage. - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - stage_indexes = [] - for model_chunk in range(num_model_chunks): - start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] - end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] - stage_indexes.append([start_idx, end_idx]) - - return stage_indexes + layers_per_stage: List[int], + stage: int, + num_model_chunks: int = 1, + num_stages: int = 0, + ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: + """ + Get the start index and end index of layers for each stage. 
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 228489ea9aca..39fd437521c4 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -21,7 +21,7 @@
     "BertPolicy",
     "BertModelPolicy",
     "BertForPreTrainingPolicy",
-    "BertLMdHeadModelPolicy",
+    "BertLMHeadModelPolicy",
     "BertForMaskedLMPolicy",
     "BertForNextSentencePredictionPolicy",
     "BertForSequenceClassificationPolicy",
@@ -242,48 +242,53 @@ def postprocess(self):
         return self.model
 
     def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
-        """If under pipeline parallel setting, replacing the original forward method of huggingface
-        to customized forward method, and add this changing to policy."""
-        if self.pipeline_stage_manager:
-            stage_manager = self.pipeline_stage_manager
-            if self.model.__class__.__name__ == "BertModel":
-                module = self.model
-            else:
-                module = self.model.bert
-
-            # num_model_chunks > 1 if interleaved
-            num_model_chunks = stage_manager.num_model_chunks
-            layers_per_stage = Policy.distribute_layers(
-                len(module.encoder.layer), stage_manager.num_stages * num_model_chunks
-            )
+        """
+        If under the pipeline parallel setting, replace the original forward method of the
+        HuggingFace model with a customized one and register this change in the policy.
+        """
+        if self.pipeline_stage_manager is None:
+            return
+
+        stage_manager = self.pipeline_stage_manager
+        if self.model.__class__.__name__ == "BertModel":
+            module = self.model
+        else:
+            module = self.model.bert
+
+        if stage_manager.is_interleave:
+            layers_per_stage = self.distribute_layers(
+                len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
+            )
+            stage_manager.stage_indices = Policy.get_stage_index(
+                layers_per_stage,
+                stage_manager.stage,
+                num_model_chunks=stage_manager.num_model_chunks,
+                num_stages=stage_manager.num_stages,
+            )
+            method_replacement = {
+                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+            }
 
-            if num_model_chunks > 1:
-                stage_index = Policy.get_stage_index(
-                    layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks
-                )
-            else:
-                stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-
-            if num_model_chunks == 1:
-                method_replacement = {
-                    "forward": partial(
-                        new_forward,
-                        stage_manager=stage_manager,
-                        stage_index=stage_index,
-                        shard_config=self.shard_config,
-                    )
-                }
-            # for interleaved, stage index for each forward is chosen in scheduler
-            else:
-                stage_manager.set_interleaved_device_layers(stage_index)
-                method_replacement = {
-                    "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
-                }
-            self.append_or_create_method_replacement(
-                description=method_replacement, policy=policy, target_key=model_cls
-            )
+        else:
+            layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                "forward": partial(
+                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                )
+            }
 
-        return
+        self.append_or_create_method_replacement(
+            description=method_replacement, policy=policy, target_key=model_cls
+        )
+ """ + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages ) + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + shard_config=self.shard_config + ) + } - if num_model_chunks > 1: - stage_index = Policy.get_stage_index( - layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks + else: + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config ) - else: - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - - if num_model_chunks == 1: - method_replacement = { - "forward": partial( - new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config, - ) - } - # for interleaved, stage index for each forward is chosen in scheduler - else: - stage_manager.set_interleaved_device_layers(stage_index) - method_replacement = { - "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) - } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) + } - return + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -296,30 +301,33 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - num_model_chunks = stage_manager.num_model_chunks - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * num_model_chunks - ) - if stage_manager.is_first_device(): - held_layers.append(module.embeddings) - if num_model_chunks > 1: - stage_index = Policy.get_stage_index( - layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks ) - else: - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - - # interleaved stage index for one device comes in pairs, e.g.[[0,3],[6,9]] - if all(isinstance(item, list) for item in stage_index): - for i in range(len(stage_index)): - start_idx, end_idx = stage_index[i] + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages + ) + if stage_manager.is_first_stage(-1): + held_layers.append(module.embeddings) + for start_idx, end_idx in stage_indices: held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(-1): + 
@@ -510,7 +518,8 @@ def get_held_layers(self) -> List[Module]:
         """
         held_layers = super().get_held_layers()
         stage_manager = self.pipeline_stage_manager
-        if stage_manager.is_last_device():
+        if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1):
             held_layers.append(self.model.dropout)
             held_layers.append(self.model.classifier)
         return held_layers
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index ca9a17784013..a285874d218b 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -4,7 +4,6 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
-from colossalai.pipeline.schedule import PipelineSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
 __all__ = ["ShardConfig"]
@@ -18,7 +17,6 @@ class ShardConfig:
     Args:
         tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group.
         pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
-        scheduler (Optional[PipelineSchedule]): If using interleaved pp, it's necessary to specify the scheduler for layer assignment for each device.
         enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True.
         enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False.
         enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
@@ -30,7 +28,6 @@ class ShardConfig:
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
-    scheduler: Optional[PipelineSchedule] = None
     enable_tensor_parallelism: bool = True
     enable_fused_normalization: bool = False
     enable_flash_attention: bool = False
diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py
index ef51f938dc4f..af87029d0986 100644
--- a/examples/language/bert/data.py
+++ b/examples/language/bert/data.py
@@ -88,20 +88,22 @@ def train_dataloader(self):
         )
 
     def val_dataloader(self):
+        # TODO: drop_last is set to True for now to avoid errors when using PP,
+        # as the last batch may not be divisible by the number of microbatches
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
                 for x in self.eval_splits
             ]
 
     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
                 for x in self.eval_splits
             ]
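Review note: dropping the ragged tail batch at the dataloader level replaces the old per-batch workaround removed from finetune.py below; the schedule asserts divisibility, so a short last batch would crash. The constraint with assumed numbers:

```python
# Why drop_last=True: each batch must split evenly into microbatches.
eval_batch_size, num_microbatch = 32, 4
dataset_len = 871                             # hypothetical validation split size
tail = dataset_len % eval_batch_size          # 7 samples left over
assert eval_batch_size % num_microbatch == 0  # full batches pass the schedule's assert
assert tail % num_microbatch != 0             # the tail batch would trip it
```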
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index 8ab3eeb3110c..48b294dc3f7b 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -57,18 +57,14 @@ def evaluate_model(
 
     def evaluate_subset(dataloader: DataLoader):
         use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device()
+        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
+            None if not booster.plugin.stage_manager.is_interleave else -1
+        )
 
         accum_loss = torch.zeros(1, device=get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
             labels = batch["labels"]
             if use_pipeline:
-                """skip the last batch with batch size 31 for interleaved pipeline parallel
-                as the number of microbatches needs to be a multiple of pipeline parallel devices
-                """
-                if booster.plugin.pp_style == "interleaved" and len(labels) < 32:
-                    continue
                 pg_mesh = booster.plugin.pg_mesh
                 pp_group = booster.plugin.pp_group
                 current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
@@ -93,6 +89,7 @@ def evaluate_subset(dataloader: DataLoader):
             elif current_rank in current_pp_group_ranks:
                 object_list = [None, None]
                 dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)
+
                 metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
                 accum_loss.add_(object_list[1].to(get_current_device()))
 
@@ -138,7 +135,8 @@ def train_epoch(
     coordinator: DistCoordinator,
 ):
     use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device()
+    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
+        None if not booster.plugin.stage_manager.is_interleave else -1
+    )
     print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
     total_step = len(train_dataloader)
diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py
index 989f88c787b5..5034335ec9e6 100644
--- a/tests/test_pipeline/test_schedule/test_interleaved.py
+++ b/tests/test_pipeline/test_schedule/test_interleaved.py
@@ -4,6 +4,7 @@
 
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 
 import colossalai
@@ -11,31 +12,23 @@
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 
+NUM_LAYER = 8
+DIM = 4
+
 
 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 8)
-        self.linear3 = nn.Linear(8, 8)
-        self.linear4 = nn.Linear(8, 8)
-        self.linear5 = nn.Linear(8, 8)
-        self.linear6 = nn.Linear(8, 8)
-        self.linear7 = nn.Linear(8, 8)
-        self.linear8 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])
 
     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        x = self.linear7(x)
-        x = self.linear8(x)
+        for layer in self.layers:
+            x = layer(x)
         return x
 
 
@@ -44,121 +37,142 @@ def pp_linear_fwd(
     data: torch.Tensor = None,
     input_obj: torch.Tensor = None,
     stage_mgr: PipelineStageManager = None,
-    num_chunks: int = None,
     model_chunk_id: int = None,
 ):
-    if stage_mgr.is_first_stage() and model_chunk_id == 0:
+    if stage_mgr.is_first_stage(model_chunk_id):
         return {"input_obj": forward(data)}
-    elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
+    elif stage_mgr.is_last_stage(model_chunk_id):
         return forward(input_obj)
     else:
         return {"input_obj": forward(input_obj)}
 
 
-@parameterize("num_micro_batches", [4, 8, 12])
-def examine_pp(num_micro_batches):
+def run_pp(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
     """
     This test is to examine the correctness of interleaved 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
""" - world_size = torch.distributed.get_world_size() - local_rank = torch.distributed.get_rank() - seed_all(1453) - - NUM_MICRO_BATCHS = num_micro_batches - BATCH_SIZE = 24 - NUM_CHUNKS = 2 + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + port=port, + host="localhost" + ) # create model + seed_all(1453) torch_model = MlpModel().cuda() - pp_model = copy.deepcopy(torch_model).cuda() - DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 - pg_mesh = ProcessGroupMesh(1, world_size, 1) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True, num_model_chunks=NUM_CHUNKS) + pg_mesh = ProcessGroupMesh(world_size) + stage_manager = PipelineStageManager( + pg_mesh, + pipeline_axis=0, + enable_interleave=True, + num_model_chunks=num_model_chunk + ) schedule = InterleavedSchedule( stage_manager=stage_manager, - num_microbatches=NUM_MICRO_BATCHS, - microbatch_size=None, - num_model_chunks=NUM_CHUNKS, + num_model_chunks=num_model_chunk, + num_microbatch=num_microbatch, ) sharded_model = torch.nn.ModuleList() - for idx, (_, sub_model) in enumerate(pp_model.named_children()): - if idx % (world_size) == local_rank: + for idx, sub_model in enumerate(pp_model.layers): + if idx % world_size == rank: sub_model._forward = sub_model.forward sub_model.forward = MethodType( partial( - pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model) + pp_linear_fwd, + stage_mgr=stage_manager, + model_chunk_id=len(sharded_model) ), sub_model._forward, ) sharded_model.append(sub_model.cuda()) + assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct" # create optimizer - torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) - pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5)) - # create - seed_all(1453) - if local_rank == 0: - input_list = [torch.rand(BATCH_SIZE, 4).cuda()] - else: - input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] - torch.distributed.all_reduce(input_list[0]) + # create data + seed_all(115) + input_list = [torch.rand(batch_size, DIM).cuda()] + dist.all_reduce(input_list[0]) - criterion = lambda x, y: torch.mean(x) + def criterion(x, *args, **kwargs): return (x * x).mean() # forward and backward torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output, _) + torch_loss = criterion(torch_output) torch_loss.backward() pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + sharded_model, + iter(input_list), + criterion, + pp_optimizer, + return_loss=True, + return_outputs=True ) # check loss - if stage_manager.is_last_device(): + if stage_manager.is_last_stage(-1): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients - torch_grad = [] - for torch_p in torch_model.parameters(): - torch_grad.append(torch_p.grad.data) - - for idx, pp_p in enumerate(sharded_model.parameters()): - if idx < 2: - assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) - else: - assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + for i in range(num_model_chunk): + idx = world_size * i + rank + assert torch.allclose( + torch_model.layers[idx].weight.grad, + sharded_model[i].weight.grad + ) + assert torch.allclose( + torch_model.layers[idx].bias.grad, + sharded_model[i].bias.grad + ) # step 
     torch_optimizer.step()
     pp_optimizer.step()
 
     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
-        else:
-            assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
 
 
 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("num_model_chunk", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 4)
+def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
+    assert NUM_LAYER % num_model_chunk == 0
+    spawn(
+        run_pp,
+        nprocs=NUM_LAYER // num_model_chunk,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+        num_model_chunk=num_model_chunk,
+    )
 
 
 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
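Review note: the parametrization derives the spawned world size from NUM_LAYER // num_model_chunk (4 or 2 processes), so every combination must satisfy load_batch's divisibility asserts. A quick standalone check of all combinations (plain Python, no GPUs needed):

```python
# Verify each pytest parametrization against the schedule's divisibility rules.
NUM_LAYER = 8
for num_model_chunk in (2, 4):
    num_stages = NUM_LAYER // num_model_chunk      # spawned world size: 4 or 2
    for num_microbatch in (4, 12):
        batch_size = 12
        assert batch_size % num_microbatch == 0    # microbatch_size = 3 or 1
        assert num_microbatch % num_model_chunk == 0
        assert num_microbatch % num_stages == 0
```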