From b1e995721d7d34a0083ee5825db4bc094e17101d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 10 Nov 2023 15:58:29 +0800 Subject: [PATCH 01/10] feat: modify hybrid plugin and stage mgr --- .../booster/plugin/hybrid_parallel_plugin.py | 22 ++++--- colossalai/pipeline/stage_manager.py | 60 ++++++------------- colossalai/shardformer/shard/shard_config.py | 4 +- 3 files changed, 35 insertions(+), 51 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 9cc0d74b3556..aa0ce061f8d4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -380,29 +380,37 @@ def __init__( self.stage_manager = None self.schedule = None self.custom_policy = custom_policy - self.pp_style = pp_style assert zero_stage in (0, 1, 2) if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager( - self.pg_mesh, PP_AXIS, is_virtual=True, num_model_chunks=num_model_chunks + self.pg_mesh, + pipeline_axis=PP_AXIS, + enable_interleave=True, + num_model_chunks=num_model_chunks ) - if self.pp_style == "interleaved": + if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" self.schedule = InterleavedSchedule( stage_manager=self.stage_manager, - num_microbatches=num_microbatches, + num_microbatch=num_microbatches, microbatch_size=microbatch_size, num_model_chunks=num_model_chunks, ) - else: + elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( - self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size ) + else: + raise NotImplementedError() + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) @@ -410,7 +418,7 @@ def __init__( self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, - scheduler=self.schedule, + pipeline_scheduler=self.schedule, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 63ac0373d543..31324f5fdb8a 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -19,13 +19,18 @@ class PipelineStageManager: """ def __init__( - self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False, num_model_chunks=1 + self, + pg_mesh: ProcessGroupMesh, + pipeline_axis: int, + enable_interleave: bool = False, + num_model_chunks: Optional[int] = None, ) -> None: self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + # init prev and next coord coord = self.pg_mesh.coordinate() # the prev rank of rank0 is the last rank @@ -34,10 +39,6 @@ def __init__( # the next rank of the last rank is rank0 next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") - # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers - self.num_model_chunks = num_model_chunks - self.model_chunk_id = 0 - self.layers = Optional[List[List[int]]] # init p2p process groups stages = list(range(self.num_stages)) @@ -47,46 +48,35 @@ def __init__( ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - if is_virtual: + self.is_interleave = enable_interleave + if enable_interleave: + # use circle p2p communication # add the process group of the first rank and the last rank - # only used in interleaved pipeline for now group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) if self.stage in [stages[0], stages[-1]]: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - def is_first_stage(self) -> bool: + # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers + self.num_model_chunks: int = num_model_chunks + + def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the first stage. Returns: bool: Whether the current stage is the first stage. """ - return self.stage == 0 and self.model_chunk_id == 0 + assert not self.is_interleave or model_chunk_id is not None + return self.stage and (not self.is_interleave or model_chunk_id == 0) - def is_last_stage(self) -> bool: + def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the last stage. Returns: bool: Whether the current stage is the last stage. """ - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 - - # introduced due to interleaved pipeline parallel, as the first/last device may also hold intermediate stages - def is_first_device(self) -> bool: - """Is the current stage on the first device. - - Returns: - bool: Whether the current stage is on the first device. - """ - return self.stage == 0 - - def is_last_device(self) -> bool: - """Is the current stage on the last device. - - Returns: - bool: Whether the current stage on the last device. - """ - return self.stage == self.num_stages - 1 + assert not self.is_interleave or model_chunk_id is not None + return self.stage == self.num_stages - 1 and (not self.is_interleave or model_chunk_id == self.num_model_chunks - 1) @property def num_stages(self) -> int: @@ -154,17 +144,3 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: ProcessGroup: Process group of the given stages. """ return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) - - def set_interleaved_model_chunk_id(self, model_chunk_id: int): - """For interleaved pipeline parallel, set the model chunk id for the device at the current stage. - Args: - model_chunk_id (int): the id of the current model chunk for the device. - """ - self.model_chunk_id = model_chunk_id - - def set_interleaved_device_layers(self, layers: List[List[int]]): - """For interleaved pipeline parallel, set the layer chunks for the device. - Args: - layers (List[List[int]]): list of layer chunks for the device. - """ - self.layers = layers diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index ca9a17784013..fbe038193345 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -18,7 +18,7 @@ class ShardConfig: Args: tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group. pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. - scheduler (Optional[PipelineSchedule]): If using interleaved pp, it's necessary to specify the scheduler for layer assignment for each device. + pipeline_scheduler (Optional[PipelineSchedule]): If using interleaved pp, it's necessary to specify the scheduler for layer assignment for each device. enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True. enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False. enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. @@ -30,7 +30,7 @@ class ShardConfig: """ tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None - scheduler: Optional[PipelineSchedule] = None + pipeline_scheduler: Optional[PipelineSchedule] = None enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_flash_attention: bool = False From 37bd7913346ec4bf8e65c4ecd8da219b53ba9972 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 10 Nov 2023 18:16:57 +0800 Subject: [PATCH 02/10] feat: fix interleave test --- .../booster/plugin/hybrid_parallel_plugin.py | 2 +- .../pipeline/schedule/interleaved_pp.py | 161 +++++++---------- colossalai/pipeline/stage_manager.py | 14 +- .../test_schedule/test_interleaved.py | 170 ++++++++++-------- 4 files changed, 170 insertions(+), 177 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index aa0ce061f8d4..ac4acc4fac20 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -398,9 +398,9 @@ def __init__( assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" self.schedule = InterleavedSchedule( stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, microbatch_size=microbatch_size, - num_model_chunks=num_model_chunks, ) elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 4e1286448589..cad33efa6872 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -19,23 +19,23 @@ class InterleavedSchedule(PipelineSchedule): def __init__( self, stage_manager: PipelineStageManager, - num_microbatches: Optional[int] = None, + num_model_chunks: int, + num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, - num_model_chunks: Optional[int] = 1, ) -> None: super().__init__(stage_manager) assert ( - num_microbatches is not None or microbatch_size is not None - ), "Either num_microbatches or microbatch_size should be provided" + num_microbatch is not None or microbatch_size is not None + ), "Either num_microbatch or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) - self.num_microbatches = num_microbatches + self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size - self.batch: Optional[Any] = None - self.batch_size: Optional[int] = None - self.microbatch_offset: Optional[int] = None - self._use_microbatch_size = num_microbatches is None self.num_model_chunks = num_model_chunks + self.batch: Any + self.batch_size: int + self.microbatch_offset: List[int] + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -49,22 +49,25 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] - if not self._use_microbatch_size: + if self.num_microbatch is not None: assert ( - self.batch_size % self.num_microbatches == 0 - ), "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches - else: + self.batch_size % self.num_microbatch == 0 + ), "Batch size should divided by the number of microbatch" + self.microbatch_size = self.batch_size // self.num_microbatch + elif self.microbatch_size is not None: assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" - self.num_microbatches = self.batch_size // self.microbatch_size + self.num_microbatch = self.batch_size // self.microbatch_size + else: + raise ValueError( + "Either num_microbatch or microbatch_size should be provided") assert ( - self.num_microbatches % self.num_model_chunks == 0 - ), "Number of microbatches should be an integer multiple of number of model chunks" + self.num_microbatch % self.num_model_chunks == 0 + ), "Number of microbatch should be an integer multiple of number of model chunks" assert ( - self.num_microbatches % self.stage_manager.num_stages == 0 - ), "Number of microbatches should be an integer multiple of number of pipeline parallel devices" + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" def load_micro_batch(self, model_chunk_id: int) -> Any: """Load a micro batch from the current batch. @@ -79,7 +82,7 @@ def load_micro_batch(self, model_chunk_id: int) -> Any: self.microbatch_offset[model_chunk_id] += self.microbatch_size return tree_map(partial(to_device, device=get_current_device()), micro_batch) - def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: + def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. Args: @@ -91,36 +94,10 @@ def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: """ microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages - if not forward: + if not is_forward: model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id - def is_first_stage(self, model_chunk_id: int) -> bool: - """Is the current virtual stage the first stage - - Args: - model_chunk_id (int): The current model chunk idx. - - Returns: - bool: Whether the current virtual stage is the first stage. - """ - if self.stage_manager.is_first_device() and model_chunk_id == 0: - return True - return False - - def is_last_stage(self, model_chunk_id: int) -> bool: - """Is the current virtual stage the last stage - - Args: - model_chunk_id (int): The current model chunk idx. - - Returns: - bool: Whether the current virtual stage is the last stage. - """ - if self.stage_manager.is_last_device() and model_chunk_id == self.num_model_chunks - 1: - return True - return False - def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For interleaved 1F1B. @@ -132,7 +109,7 @@ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. """ - if self.is_first_stage(model_chunk_id): + if self.stage_manager.is_first_stage(model_chunk_id): input_tensor = None else: input_tensor = self.comm.recv_forward(prev_rank) @@ -150,7 +127,7 @@ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.is_last_stage(model_chunk_id): + if self.stage_manager.is_last_stage(model_chunk_id): output_tensor_grad = None else: output_tensor_grad = self.comm.recv_backward(next_rank) @@ -166,7 +143,7 @@ def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.is_last_stage(model_chunk_id): + if not self.stage_manager.is_last_stage(model_chunk_id): self.comm.send_forward(output_object, next_rank) def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: @@ -178,7 +155,7 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.is_first_stage(model_chunk_id): + if not self.stage_manager.is_first_stage(model_chunk_id): self.comm.send_backward(input_object, prev_rank) def forward_step( @@ -209,14 +186,16 @@ def forward_step( if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) else: + # TODO + raise NotImplementedError() # in shardformer, each device still has the entire model, so we need to pass the relevant stage layers - if input_obj is None: - input_obj = {} - input_obj["stage_index"] = self.stage_manager.layers[model_chunk_id] - output_obj = model_forward(model_chunk, micro_batch, input_obj) + # if input_obj is None: + # input_obj = {} + # input_obj["stage_index"] = self.stage_manager.layers[model_chunk_id] + # output_obj = model_forward(model_chunk, micro_batch, input_obj) - if self.is_last_stage(model_chunk_id): - loss = criterion(output_obj, micro_batch) / self.num_microbatches + if self.stage_manager.is_last_stage(model_chunk_id): + loss = criterion(output_obj, micro_batch) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: @@ -279,7 +258,7 @@ def forward_backward_step( return_loss: bool = False, return_outputs: bool = False, ) -> dict: - """Runs interleaved 1F1B schedule, with communication between pipeline stages. + """Runs interleaved schedule, with communication between pipeline stages. Args: model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification @@ -297,46 +276,42 @@ def forward_backward_step( assert forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) - num_model_chunks = self.num_model_chunks - # num_warmup_microbatches is the step when not all the processes are working - num_microbatches = self.num_microbatches * num_model_chunks + num_microbatch = self.num_microbatch * self.num_model_chunks if forward_only: - num_warmup_microbatches = num_microbatches + num_warmup_microbatch = num_microbatch else: - num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages + num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) - num_microbatches_remaining = num_microbatches - num_warmup_microbatches + num_microbatch_remaining = num_microbatch - num_warmup_microbatch # Input, output tensors only need to be saved when doing backward passes input_objs = None output_objs = None if not forward_only: - input_objs = [[] for _ in range(num_model_chunks)] - output_objs = [[] for _ in range(num_model_chunks)] + input_objs = [[] for _ in range(self.num_model_chunks)] + output_objs = [[] for _ in range(self.num_model_chunks)] - outputs = [] if return_outputs and self.stage_manager.is_last_device() else None + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None - if return_loss and self.stage_manager.is_last_device(): + if return_loss and self.stage_manager.is_last_stage(): accum_loss = torch.zeros(1, device=get_current_device()) else: accum_loss = None # for ranks except the first one, get into recv state - # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) input_obj = self.recv_forward(0) if not forward_only: input_objs[0].append(input_obj) # Run warmup forward passes. - for i in range(num_warmup_microbatches): - model_chunk_id = self.get_model_chunk_id(i, forward=True) - self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id) - # recv first on first rank to avoid sending or recving at the same time - if self.stage_manager.is_first_device(): + for i in range(num_warmup_microbatch): + model_chunk_id = self.get_model_chunk_id(i, is_forward=True) + # recv first on first rank to avoid sending or receiving at the same time + if self.stage_manager.is_first_stage(): input_obj = self.recv_forward(model_chunk_id) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) self.send_forward(model_chunk_id, output_obj) @@ -348,20 +323,20 @@ def forward_backward_step( if not forward_only: output_objs[model_chunk_id].append(output_obj) self.send_forward(model_chunk_id, output_obj) - if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches: - break - else: - model_chunk_id = self.get_model_chunk_id(i + 1, forward=True) - input_obj = self.recv_forward(model_chunk_id) - if not forward_only: - input_objs[model_chunk_id].append(input_obj) + if num_microbatch_remaining == 0 \ + and i + 1 == num_warmup_microbatch: + break + + model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True) + input_obj = self.recv_forward(model_chunk_id) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) # Run 1F1B in steady state. - for i in range(num_microbatches_remaining): - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) - self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id) - last_iteration = i == (num_microbatches_remaining - 1) + for i in range(num_microbatch_remaining): + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) + last_iteration = i == num_microbatch_remaining - 1 output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if forward_only: @@ -376,7 +351,7 @@ def forward_backward_step( input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) - model_chunk_id = self.get_model_chunk_id(i, forward=False) + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) output_obj_grad = self.recv_backward(model_chunk_id) # Pop output_obj and output_obj from the start of the list for @@ -390,17 +365,15 @@ def forward_backward_step( if last_iteration: input_obj = None else: - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True) input_obj = self.recv_forward(model_chunk_id) - model_chunk_id = self.get_model_chunk_id(i, forward=False) + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) self.send_backward(model_chunk_id, input_obj_grad) # Run cooldown backward passes. if not forward_only: - for i in range(num_microbatches_remaining, num_microbatches): - model_chunk_id = self.get_model_chunk_id(i, forward=False) - self.stage_manager.set_interleaved_model_chunk_id(model_chunk_id) - # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") + for i in range(num_microbatch_remaining, num_microbatch): + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) input_obj = input_objs[model_chunk_id].pop(0) output_obj = output_objs[model_chunk_id].pop(0) output_obj_grad = self.recv_backward(model_chunk_id) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 31324f5fdb8a..5abc08cb7625 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -66,8 +66,11 @@ def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: Returns: bool: Whether the current stage is the first stage. """ - assert not self.is_interleave or model_chunk_id is not None - return self.stage and (not self.is_interleave or model_chunk_id == 0) + assert self.is_interleave or model_chunk_id is None + if not self.is_interleave or model_chunk_id is None: + return self.stage == 0 + else: + return self.stage == 0 and model_chunk_id == 0 def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the last stage. @@ -75,8 +78,11 @@ def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: Returns: bool: Whether the current stage is the last stage. """ - assert not self.is_interleave or model_chunk_id is not None - return self.stage == self.num_stages - 1 and (not self.is_interleave or model_chunk_id == self.num_model_chunks - 1) + assert self.is_interleave or model_chunk_id is None + if not self.is_interleave or model_chunk_id is None: + return self.stage == self.num_stages - 1 + else: + return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 989f88c787b5..015462272a60 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -4,6 +4,7 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn import colossalai @@ -11,31 +12,23 @@ from colossalai.interface import OptimizerWrapper from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all +NUM_LAYER = 8 +DIM = 4 + class MlpModel(nn.Module): def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(4, 8) - self.linear2 = nn.Linear(8, 8) - self.linear3 = nn.Linear(8, 8) - self.linear4 = nn.Linear(8, 8) - self.linear5 = nn.Linear(8, 8) - self.linear6 = nn.Linear(8, 8) - self.linear7 = nn.Linear(8, 8) - self.linear8 = nn.Linear(8, 4) + super().__init__() + self.layers = nn.ModuleList( + [nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)] + ) def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - x = self.linear4(x) - x = self.linear5(x) - x = self.linear6(x) - x = self.linear7(x) - x = self.linear8(x) + for layer in self.layers: + x = layer(x) return x @@ -44,121 +37,142 @@ def pp_linear_fwd( data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None, - num_chunks: int = None, model_chunk_id: int = None, ): - if stage_mgr.is_first_stage() and model_chunk_id == 0: + if stage_mgr.is_first_stage(model_chunk_id): return {"input_obj": forward(data)} - elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: + elif stage_mgr.is_last_stage(model_chunk_id): return forward(input_obj) else: return {"input_obj": forward(input_obj)} -@parameterize("num_micro_batches", [4, 8, 12]) -def examine_pp(num_micro_batches): +def run_pp( + rank: int, + world_size: int, + port: int, + num_microbatch: int, + batch_size: int, + num_model_chunk: int, +): """ This test is to examine the correctness of interleaved 1F1B, compared with torch. Be aware it contains some hardcodes. """ - world_size = torch.distributed.get_world_size() - local_rank = torch.distributed.get_rank() - seed_all(1453) - - NUM_MICRO_BATCHS = num_micro_batches - BATCH_SIZE = 24 - NUM_CHUNKS = 2 + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + port=port, + host="localhost" + ) # create model + seed_all(1453) torch_model = MlpModel().cuda() - pp_model = copy.deepcopy(torch_model).cuda() - DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 - pg_mesh = ProcessGroupMesh(1, world_size, 1) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True, num_model_chunks=NUM_CHUNKS) + pg_mesh = ProcessGroupMesh(world_size) + stage_manager = PipelineStageManager( + pg_mesh, + pipeline_axis=0, + enable_interleave=True, + num_model_chunks=num_model_chunk + ) schedule = InterleavedSchedule( stage_manager=stage_manager, - num_microbatches=NUM_MICRO_BATCHS, - microbatch_size=None, - num_model_chunks=NUM_CHUNKS, + num_model_chunks=num_model_chunk, + num_microbatch=num_microbatch, ) sharded_model = torch.nn.ModuleList() - for idx, (_, sub_model) in enumerate(pp_model.named_children()): - if idx % (world_size) == local_rank: + for idx, sub_model in enumerate(pp_model.layers): + if idx % world_size == rank: sub_model._forward = sub_model.forward sub_model.forward = MethodType( partial( - pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model) + pp_linear_fwd, + stage_mgr=stage_manager, + model_chunk_id=len(sharded_model) ), sub_model._forward, ) sharded_model.append(sub_model.cuda()) + assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct" # create optimizer - torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) - pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5)) - # create - seed_all(1453) - if local_rank == 0: - input_list = [torch.rand(BATCH_SIZE, 4).cuda()] - else: - input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] - torch.distributed.all_reduce(input_list[0]) + # create data + seed_all(115) + input_list = [torch.rand(batch_size, DIM).cuda()] + dist.all_reduce(input_list[0]) - criterion = lambda x, y: torch.mean(x) + def criterion(x, *args, **kwargs): return torch.mean(x) # forward and backward torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output, _) + torch_loss = criterion(torch_output) torch_loss.backward() pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + sharded_model, + iter(input_list), + criterion, + pp_optimizer, + return_loss=True, + return_outputs=True ) # check loss - if stage_manager.is_last_device(): + if stage_manager.is_last_stage(): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients - torch_grad = [] - for torch_p in torch_model.parameters(): - torch_grad.append(torch_p.grad.data) - - for idx, pp_p in enumerate(sharded_model.parameters()): - if idx < 2: - assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) - else: - assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + for i in range(num_model_chunk): + idx = world_size * i + rank + assert torch.allclose( + torch_model.layers[idx].weight.grad, + sharded_model[i].weight.grad + ) + assert torch.allclose( + torch_model.layers[idx].bias.grad, + sharded_model[i].bias.grad + ) # step torch_optimizer.step() pp_optimizer.step() # check updated param - torch_param = [] - for torch_p in torch_model.parameters(): - torch_param.append(torch_p.data) - for idx, pp_p in enumerate(sharded_model.parameters()): - if idx < 2: - assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) - else: - assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") - examine_pp() + for i in range(num_model_chunk): + idx = world_size * i + rank + assert torch.allclose( + torch_model.layers[idx].weight, + sharded_model[i].weight + ) + assert torch.allclose( + torch_model.layers[idx].bias, + sharded_model[i].bias + ) @pytest.mark.dist +@pytest.mark.parametrize("num_microbatch", [4, 12]) +@pytest.mark.parametrize("batch_size", [12]) +@pytest.mark.parametrize("num_model_chunk", [2, 4]) @rerun_if_address_is_in_use() -def test_pp(): - spawn(run_dist, 4) +def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): + assert NUM_LAYER % num_model_chunk == 0 + spawn( + run_pp, + nprocs=NUM_LAYER // num_model_chunk, + num_microbatch=num_microbatch, + batch_size=batch_size, + num_model_chunk=num_model_chunk + ) if __name__ == "__main__": - test_pp() + test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4) From 4993e8d44735ddd0ca2191a35904d276cb5bdb7c Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 13 Nov 2023 14:46:54 +0800 Subject: [PATCH 03/10] style: remove unused code and add comments --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 1 - colossalai/pipeline/stage_manager.py | 8 ++++++++ colossalai/shardformer/policies/base_policy.py | 7 ------- colossalai/shardformer/shard/shard_config.py | 3 --- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ac4acc4fac20..44eb70836c23 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -418,7 +418,6 @@ def __init__( self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, - pipeline_scheduler=self.schedule, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 5abc08cb7625..08b49380c60d 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -63,6 +63,10 @@ def __init__( def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the first stage. + NOTE: + 1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device. + 2. invoke is_first_stage() with model_chunk_id=None is equivalent to invoke is_first_device() + Returns: bool: Whether the current stage is the first stage. """ @@ -75,6 +79,10 @@ def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the last stage. + NOTE: + 1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device. + 2. invoke is_last_stage() with model_chunk_id=None is equivalent to invoke is_last_device() + Returns: bool: Whether the current stage is the last stage. """ diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 789565e5a010..5c7c89b77c23 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -9,7 +9,6 @@ from torch import Tensor from torch.nn import Module -from colossalai.pipeline.schedule import PipelineSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from ..layer.parallel_module import ParallelModule @@ -100,12 +99,6 @@ def pipeline_stage_manager(self) -> Optional[PipelineStageManager]: return self.shard_config.pipeline_stage_manager return None - @property - def scheduler(self) -> Optional[PipelineSchedule]: - if self.shard_config is not None: - return self.shard_config.scheduler - return None - @abstractmethod def config_sanity_check(self): """ diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index fbe038193345..a285874d218b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -4,7 +4,6 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.pipeline.schedule import PipelineSchedule from colossalai.pipeline.stage_manager import PipelineStageManager __all__ = ["ShardConfig"] @@ -18,7 +17,6 @@ class ShardConfig: Args: tensor_parallel_process_group (Optional[ProcessGroup]): The process group of tensor parallelism, it's necessary when using tensor parallel. Defaults to None, which is the global process group. pipeline_stage_manager (Optional[PipelineStageManager]): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism. - pipeline_scheduler (Optional[PipelineSchedule]): If using interleaved pp, it's necessary to specify the scheduler for layer assignment for each device. enable_tensor_parallelism (bool): Whether to use tensor parallelism. Defaults to True. enable_fused_normalization (bool): Whether to use fused layernorm. Defaults to False. enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. @@ -30,7 +28,6 @@ class ShardConfig: """ tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None - pipeline_scheduler: Optional[PipelineSchedule] = None enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_flash_attention: bool = False From a25e53df553b94319af039de4ab355f7acbac50d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 13 Nov 2023 18:21:41 +0800 Subject: [PATCH 04/10] feat: modify bert policy --- .../shardformer/policies/base_policy.py | 56 ++++---- colossalai/shardformer/policies/bert.py | 128 ++++++++++-------- 2 files changed, 97 insertions(+), 87 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 5c7c89b77c23..725441f25c27 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch.nn as nn @@ -215,29 +215,31 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]: @staticmethod def get_stage_index( - layers_per_stage: List[int], stage: int, num_stages=None, num_model_chunks=1 - ) -> Union[List[int], List[List[int]]]: - # num_stages info is only needed for interleaved pipeline stage assignment - if num_stages is None: - """ - get the start index and end index of layers for each stage. - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - start_idx = num_layers_per_stage_accumulated[stage] - end_idx = num_layers_per_stage_accumulated[stage + 1] - - return [start_idx, end_idx] - else: - """ - interleaved pipeline: get the start index and end index PAIRS for each stage. - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - stage_indexes = [] - for model_chunk in range(num_model_chunks): - start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] - end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] - stage_indexes.append([start_idx, end_idx]) - - return stage_indexes + layers_per_stage: List[int], + stage: int, + num_model_chunks: int = 1, + num_stages: int = 0, + ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: + """ + Get the start index and end index of layers for each stage. + + Args: + layers_per_stage (List[int]): number of layers for each stage + stage (int): the stage index + num_stages (int): number of stages + num_model_chunks (int): number of model chunks + + Returns: + - Tuple[int, int]: the start index and end index of this stage + - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk + + """ + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + stage_indices = [] + for model_chunk in range(num_model_chunks): + start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] + end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] + stage_indices.append([start_idx, end_idx]) + + return stage_indices[0] if num_model_chunks == 1 else stage_indices diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 228489ea9aca..1e1803d64bc6 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -21,7 +21,7 @@ "BertPolicy", "BertModelPolicy", "BertForPreTrainingPolicy", - "BertLMdHeadModelPolicy", + "BertLMHeadModelPolicy", "BertForMaskedLMPolicy", "BertForNextSentencePredictionPolicy", "BertForSequenceClassificationPolicy", @@ -242,48 +242,53 @@ def postprocess(self): return self.model def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "BertModel": - module = self.model - else: - module = self.model.bert - - # num_model_chunks > 1 if interleaved - num_model_chunks = stage_manager.num_model_chunks - layers_per_stage = Policy.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * num_model_chunks + """ + If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy. + """ + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks ) + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages + ) + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + shard_config=self.shard_config + ) + } - if num_model_chunks > 1: - stage_index = Policy.get_stage_index( - layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks + else: + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config ) - else: - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - - if num_model_chunks == 1: - method_replacement = { - "forward": partial( - new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config, - ) - } - # for interleaved, stage index for each forward is chosen in scheduler - else: - stage_manager.set_interleaved_device_layers(stage_index) - method_replacement = { - "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) - } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) + } - return + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -296,30 +301,33 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - num_model_chunks = stage_manager.num_model_chunks - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * num_model_chunks - ) - if stage_manager.is_first_device(): - held_layers.append(module.embeddings) - if num_model_chunks > 1: - stage_index = Policy.get_stage_index( - layers_per_stage, stage_manager.stage, stage_manager.num_stages, num_model_chunks + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks ) - else: - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - - # interleaved stage index for one device comes in pairs, e.g.[[0,3],[6,9]] - if all(isinstance(item, list) for item in stage_index): - for i in range(len(stage_index)): - start_idx, end_idx = stage_index[i] + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages + ) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + for start_idx, end_idx in stage_indices: held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.pooler) + else: - start_idx, end_idx = stage_index + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) - - if stage_manager.is_last_device(): - held_layers.append(module.pooler) + if stage_manager.is_last_stage(): + held_layers.append(module.pooler) return held_layers @@ -510,7 +518,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_device(): + if stage_manager.is_last_stage(): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers From 1b63f037294ec13a39dbba46ea878f295accd665 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 13 Nov 2023 18:22:38 +0800 Subject: [PATCH 05/10] fix: fix bert shard forward --- colossalai/pipeline/schedule/interleaved_pp.py | 12 +++++------- colossalai/pipeline/stage_manager.py | 7 ++++++- colossalai/shardformer/modeling/bert.py | 3 ++- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index cad33efa6872..acfa93cf0762 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -186,13 +186,11 @@ def forward_step( if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) else: - # TODO - raise NotImplementedError() - # in shardformer, each device still has the entire model, so we need to pass the relevant stage layers - # if input_obj is None: - # input_obj = {} - # input_obj["stage_index"] = self.stage_manager.layers[model_chunk_id] - # output_obj = model_forward(model_chunk, micro_batch, input_obj) + # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + internal_inputs["model_chunk_id"] = model_chunk_id + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) if self.stage_manager.is_last_stage(model_chunk_id): loss = criterion(output_obj, micro_batch) / self.num_microbatch diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 08b49380c60d..0bce7ae9d0c1 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -23,8 +23,10 @@ def __init__( pg_mesh: ProcessGroupMesh, pipeline_axis: int, enable_interleave: bool = False, - num_model_chunks: Optional[int] = None, + num_model_chunks: int = 1, ) -> None: + assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False" + self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None @@ -60,6 +62,9 @@ def __init__( # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers self.num_model_chunks: int = num_model_chunks + # for shardformer, hold stage indices of model + self.stage_indices: List[Tuple[int, int]] + def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the first stage. diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 7411e1d0ec46..43febcc4ce74 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -661,6 +661,7 @@ def bert_for_sequence_classification_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + model_chunk_id: Optional[int] = None, shard_config: ShardConfig = None, ): r""" @@ -697,7 +698,7 @@ def bert_for_sequence_classification_forward( shard_config=shard_config, ) - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(model_chunk_id): pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) From b1a8c8af98dadf529eb77a63f40b9ade7f644fe5 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 13 Nov 2023 18:24:06 +0800 Subject: [PATCH 06/10] fix: fix finetuning args --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 3 ++- examples/language/bert/data.py | 8 ++++---- examples/language/bert/finetune.py | 10 +++------- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 44eb70836c23..e999c27d672f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -383,6 +383,7 @@ def __init__( assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -390,7 +391,7 @@ def __init__( self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=PP_AXIS, - enable_interleave=True, + enable_interleave=pp_style == "interleaved", num_model_chunks=num_model_chunks ) diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py index ef51f938dc4f..0309864263f7 100644 --- a/examples/language/bert/data.py +++ b/examples/language/bert/data.py @@ -89,19 +89,19 @@ def train_dataloader(self): def val_dataloader(self): if len(self.eval_splits) == 1: - return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True) elif len(self.eval_splits) > 1: return [ - self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True) for x in self.eval_splits ] def test_dataloader(self): if len(self.eval_splits) == 1: - return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True) elif len(self.eval_splits) > 1: return [ - self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True) for x in self.eval_splits ] diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 8ab3eeb3110c..9f15a518712f 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -57,18 +57,13 @@ def evaluate_model( def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device() + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage() accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] if use_pipeline: - """skip the last batch with batch size 31 for interleaved pipeline parallel - as the number of microbatches needs to be a multiple of pipeline parallel devices - """ - if booster.plugin.pp_style == "interleaved" and len(labels) < 32: - continue pg_mesh = booster.plugin.pg_mesh pp_group = booster.plugin.pp_group current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) @@ -93,6 +88,7 @@ def evaluate_subset(dataloader: DataLoader): elif current_rank in current_pp_group_ranks: object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) + metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) accum_loss.add_(object_list[1].to(get_current_device())) @@ -138,7 +134,7 @@ def train_epoch( coordinator: DistCoordinator, ): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_device() + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage() print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device) total_step = len(train_dataloader) From 60855a96f2559080ae457bb1034c834e4babb1b5 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 14 Nov 2023 17:38:12 +0800 Subject: [PATCH 07/10] fix: fix is_xx_stage fn --- .../pipeline/schedule/interleaved_pp.py | 9 +++++---- colossalai/pipeline/stage_manager.py | 20 +++++++++++++------ colossalai/shardformer/modeling/bert.py | 3 +-- colossalai/shardformer/policies/bert.py | 7 ++++--- examples/language/bert/finetune.py | 6 ++++-- .../test_schedule/test_interleaved.py | 2 +- 6 files changed, 29 insertions(+), 18 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index acfa93cf0762..6a6122228771 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -183,14 +183,15 @@ def forward_step( # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + self.stage_manager.model_chunk_id = model_chunk_id if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) else: # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers internal_inputs = {} if input_obj is None else input_obj internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] - internal_inputs["model_chunk_id"] = model_chunk_id output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + self.stage_manager.model_chunk_id = None if self.stage_manager.is_last_stage(model_chunk_id): loss = criterion(output_obj, micro_batch) / self.num_microbatch @@ -293,9 +294,9 @@ def forward_backward_step( input_objs = [[] for _ in range(self.num_model_chunks)] output_objs = [[] for _ in range(self.num_model_chunks)] - outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None - if return_loss and self.stage_manager.is_last_stage(): + if return_loss and self.stage_manager.is_last_stage(-1): accum_loss = torch.zeros(1, device=get_current_device()) else: accum_loss = None @@ -309,7 +310,7 @@ def forward_backward_step( for i in range(num_warmup_microbatch): model_chunk_id = self.get_model_chunk_id(i, is_forward=True) # recv first on first rank to avoid sending or receiving at the same time - if self.stage_manager.is_first_stage(): + if self.stage_manager.is_first_stage(-1): input_obj = self.recv_forward(model_chunk_id) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) self.send_forward(model_chunk_id, output_obj) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 0bce7ae9d0c1..ed61bfcba0c7 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -64,19 +64,24 @@ def __init__( # for shardformer, hold stage indices of model self.stage_indices: List[Tuple[int, int]] + # for shardformer, hold model chunk id + self.model_chunk_id: Optional[int] = None def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the first stage. NOTE: 1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device. - 2. invoke is_first_stage() with model_chunk_id=None is equivalent to invoke is_first_device() + 2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device() Returns: bool: Whether the current stage is the first stage. """ - assert self.is_interleave or model_chunk_id is None - if not self.is_interleave or model_chunk_id is None: + if self.is_interleave and model_chunk_id is None: + model_chunk_id = self.model_chunk_id + assert self.is_interleave ^ (model_chunk_id is None), \ + "model_chunk_id must be specified when using interleaved pipeline" + if not self.is_interleave or model_chunk_id < 0: return self.stage == 0 else: return self.stage == 0 and model_chunk_id == 0 @@ -86,13 +91,16 @@ def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: NOTE: 1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device. - 2. invoke is_last_stage() with model_chunk_id=None is equivalent to invoke is_last_device() + 2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device() Returns: bool: Whether the current stage is the last stage. """ - assert self.is_interleave or model_chunk_id is None - if not self.is_interleave or model_chunk_id is None: + if self.is_interleave and model_chunk_id is None: + model_chunk_id = self.model_chunk_id + assert self.is_interleave ^ (model_chunk_id is None), \ + "model_chunk_id must be specified when using interleaved pipeline" + if not self.is_interleave or model_chunk_id < 0: return self.stage == self.num_stages - 1 else: return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1 diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 43febcc4ce74..7411e1d0ec46 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -661,7 +661,6 @@ def bert_for_sequence_classification_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, - model_chunk_id: Optional[int] = None, shard_config: ShardConfig = None, ): r""" @@ -698,7 +697,7 @@ def bert_for_sequence_classification_forward( shard_config=shard_config, ) - if stage_manager.is_last_stage(model_chunk_id): + if stage_manager.is_last_stage(): pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 1e1803d64bc6..39fd437521c4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -313,11 +313,11 @@ def get_held_layers(self) -> List[Module]: num_model_chunks=stage_manager.num_model_chunks, num_stages=stage_manager.num_stages ) - if stage_manager.is_first_stage(): + if stage_manager.is_first_stage(-1): held_layers.append(module.embeddings) for start_idx, end_idx in stage_indices: held_layers.extend(module.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(-1): held_layers.append(module.pooler) else: @@ -518,7 +518,8 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage( + None if not stage_manager.is_interleave else -1): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 9f15a518712f..48b294dc3f7b 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -57,7 +57,8 @@ def evaluate_model( def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage() + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( + None if not booster.plugin.stage_manager.is_interleave else -1) accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: @@ -134,7 +135,8 @@ def train_epoch( coordinator: DistCoordinator, ): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage() + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( + None if not booster.plugin.stage_manager.is_interleave else -1) print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device) total_step = len(train_dataloader) diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 015462272a60..7bc88438dd96 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -126,7 +126,7 @@ def criterion(x, *args, **kwargs): return torch.mean(x) ) # check loss - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(-1): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients From 1533b636d049f064687b2d53f07862cfc7b21616 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 14 Nov 2023 17:40:23 +0800 Subject: [PATCH 08/10] fix: disable fused layer --- examples/language/bert/finetune.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 48b294dc3f7b..f3535145da75 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -228,7 +228,10 @@ def main(): pp_style="interleaved", num_model_chunks=2, microbatch_size=16, - enable_all_optimization=True, + # FIXME: TODO: + # This is disable as fused layer lead to results with huge error + # The result can be retrieved at colossalai/shardformer/modeling/bert.py:224 + # enable_all_optimization=True, zero_stage=1, precision="fp16", initial_scale=1, From a64af2c4e6b00beac8bcb74c49342510d6d19485 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 15 Nov 2023 14:51:52 +0800 Subject: [PATCH 09/10] fix: fix wrong input_objs --- colossalai/pipeline/schedule/interleaved_pp.py | 9 +++++---- examples/language/bert/finetune.py | 5 +---- tests/test_pipeline/test_schedule/test_interleaved.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 6a6122228771..bc66d2f262a4 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -303,8 +303,6 @@ def forward_backward_step( # for ranks except the first one, get into recv state input_obj = self.recv_forward(0) - if not forward_only: - input_objs[0].append(input_obj) # Run warmup forward passes. for i in range(num_warmup_microbatch): @@ -320,6 +318,7 @@ def forward_backward_step( else: output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if not forward_only: + input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) self.send_forward(model_chunk_id, output_obj) @@ -329,8 +328,6 @@ def forward_backward_step( model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True) input_obj = self.recv_forward(model_chunk_id) - if not forward_only: - input_objs[model_chunk_id].append(input_obj) # Run 1F1B in steady state. for i in range(num_microbatch_remaining): @@ -366,6 +363,7 @@ def forward_backward_step( else: model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True) input_obj = self.recv_forward(model_chunk_id) + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) self.send_backward(model_chunk_id, input_obj_grad) @@ -379,6 +377,9 @@ def forward_backward_step( input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) self.send_backward(model_chunk_id, input_obj_grad) + if not forward_only: + assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) + if outputs is not None: outputs = merge_batch(outputs) return {"loss": accum_loss, "outputs": outputs} diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index f3535145da75..48b294dc3f7b 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -228,10 +228,7 @@ def main(): pp_style="interleaved", num_model_chunks=2, microbatch_size=16, - # FIXME: TODO: - # This is disable as fused layer lead to results with huge error - # The result can be retrieved at colossalai/shardformer/modeling/bert.py:224 - # enable_all_optimization=True, + enable_all_optimization=True, zero_stage=1, precision="fp16", initial_scale=1, diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 7bc88438dd96..5034335ec9e6 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -109,7 +109,7 @@ def run_pp( input_list = [torch.rand(batch_size, DIM).cuda()] dist.all_reduce(input_list[0]) - def criterion(x, *args, **kwargs): return torch.mean(x) + def criterion(x, *args, **kwargs): return (x * x).mean() # forward and backward torch_output = torch_model(input_list[0]) From 597e028718957ccfc5282e6012a5f6a731db0a43 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 16 Nov 2023 10:37:13 +0800 Subject: [PATCH 10/10] to: add TODO mark --- colossalai/pipeline/schedule/interleaved_pp.py | 1 + examples/language/bert/data.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index bc66d2f262a4..5471419958e1 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -270,6 +270,7 @@ def forward_backward_step( Returns: dict: A dict with keys: 'loss' and 'outputs'. """ + # TODO: handle arbitrary batch size when forward_only == True forward_only = not torch.is_grad_enabled() if optimizer is None: assert forward_only, "Optimizer should be passed when doing backward." diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py index 0309864263f7..af87029d0986 100644 --- a/examples/language/bert/data.py +++ b/examples/language/bert/data.py @@ -88,6 +88,8 @@ def train_dataloader(self): ) def val_dataloader(self): + # TODO: drop_last is set to True for now to avoid error when using PP + # as the last batch may not be divisible by the number of microbatches if len(self.eval_splits) == 1: return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True) elif len(self.eval_splits) > 1: