From 691193850d9c86241314e35e773d173111c076cd Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 03:48:50 +0000 Subject: [PATCH 01/47] hybrid support zbv --- .../naive_amp/mixed_precision_mixin/base.py | 2 +- .../naive_amp/mixed_precision_optimizer.py | 4 +- .../booster/plugin/hybrid_parallel_plugin.py | 46 +++++--- .../pipeline/schedule/zero_bubble_pp.py | 105 +++++++++++++----- colossalai/pipeline/stage_manager.py | 17 ++- colossalai/shardformer/modeling/llama.py | 1 + colossalai/shardformer/policies/llama.py | 12 +- colossalai/zero/gemini/gemini_ddp.py | 2 +- colossalai/zero/gemini/gemini_optimizer.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 8 +- tests/test_shardformer/test_model/_utils.py | 13 ++- .../test_model/test_shard_llama.py | 53 +++++++-- 12 files changed, 196 insertions(+), 71 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index fc7e0b74179a..b2ba47f6762d 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -43,7 +43,7 @@ def zero_grad(self): dtype: torch.dtype @abstractmethod - def pre_backward(self, loss: Tensor) -> Tensor: + def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor: """Called before backward. Args: diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 9e07bdebf8fa..700f80336cf0 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -89,9 +89,9 @@ def backward(self, loss: Tensor, *args, **kwargs): loss = self.mixed_precision.pre_backward(loss) loss.backward(*args, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) - tensor.backward(grad) + tensor.backward(grad, inputs=inputs, retain_graph=retain_graph) def zero_grad(self, *args, **kwargs): for p in self.working_to_master_map.keys(): diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1b3b765c2ff0..1323d14b320c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -28,7 +28,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer @@ -39,6 +39,7 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero.low_level import LowLevelZeroOptimizer from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from .pp_plugin_base import PipelinePluginBase @@ -315,7 +316,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient 
synchronization is not required, return.
            return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@@ -332,7 +333,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         """
         # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -538,7 +539,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
             # If gradient synchronization is not required, return.
             return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@@ -554,7 +555,7 @@
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -768,7 +769,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
         else:
             return

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
         """
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
@@ -784,7 +785,7 @@
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, retain_graph)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph)

         if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -793,7 +794,7 @@
             # If gradient synchronization is not required, return.
             return

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@@ -809,7 +810,7 @@
             None
         """
         # Call the superclass backward_by_grad method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
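[Note: the `inputs`/`retain_graph` parameters threaded through every `backward`/`backward_by_grad` wrapper above exist so the zero-bubble scheduler can split a single backward pass in two: a "B" step that computes only activation gradients (so the result can be sent to the previous stage immediately, with the graph retained), and a deferred "W" step that computes only weight gradients. A minimal sketch of that split in plain PyTorch follows; `layer`, `x`, and `grad_out` are illustrative stand-ins, not names from this patch.

import torch

# One linear layer stands in for a pipeline stage's model chunk.
layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)  # stage input (received activation)
out = layer(x)
grad_out = torch.ones_like(out)  # stand-in for the gradient received from the next stage

# B step: gradients w.r.t. the stage input only; keep the graph alive for the W step.
torch.autograd.backward(out, grad_out, inputs=[x], retain_graph=True)
grad_to_send = x.grad  # what send_backward() would ship upstream right away

# W step (scheduled later): gradients w.r.t. the weights only; the graph can now be freed.
torch.autograd.backward(out, grad_out, inputs=list(layer.parameters()), retain_graph=False)
assert all(p.grad is not None for p in layer.parameters())

Decoupling B from W is what removes the tail bubble: a stage can forward the upstream gradient before paying for its own weight-gradient computation.]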
@@ -1013,6 +1014,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + scheduler_nodes: List = None, num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, @@ -1029,6 +1031,7 @@ def __init__( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert not pp_style == "zbv" or scheduler_nodes is not None, f"scheduler_nodes must not be None when using zero bubble pipeline." if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" @@ -1088,12 +1091,13 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert pp_style in ["interleaved", "zbv"] or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2, "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -1103,14 +1107,15 @@ def __init__( self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=(pp_style == "interleaved"), + enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), + use_zbv=(pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -1119,12 +1124,20 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, ) + elif pp_style == "zbv": + self.scheduler = ZeroBubbleVPipeScheduler( + stage_manager=self.stage_manager, + schedule = scheduler_nodes, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + ) else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": @@ -1236,7 +1249,6 @@ def configure( # Replace with distributed implementation if exists optimizer = cast_to_distributed(optimizer) - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: self.logger.warning( "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. 
Disabling ZeRO.", @@ -1352,7 +1364,7 @@ def execute_pipeline( ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx, model._wait_all_gather(): - outputs = self.schedule.forward_backward_step( + outputs = self.scheduler.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index c1c4f13c68c2..03196c48c311 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -12,7 +12,7 @@ from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device, model_forward from .base import PipelineSchedule AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} @@ -33,10 +33,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): """ if (out is None) or (not deallocate_pipeline_outputs): return + print(f"{out=}") assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." - # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) - out.data.untyped_storage().resize_(0) + out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) + # out.data.untyped_storage().resize_(0) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -59,10 +60,9 @@ def __init__( self.batch_size: int self.last_batch_size: Optional[int] = None self.microbatch_offset: List[int] - - self.schedules = schedule # TODO: optim post valid self.do_post_validation = False + self.schedules = schedule # P2PMeta cache # self.enable_metadata_cache = enable_metadata_cache @@ -432,13 +432,17 @@ def forward_step( # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. # Only attention_mask from micro_batch is used + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) with self.stage_manager.switch_model_chunk_id(model_chunk_id): # fwd calculate - output_obj = model_chunk[model_chunk_id](input_obj) + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + # output_obj = model_chunk[model_chunk_id](input_obj) # last layer in model - if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - loss = criterion(output_obj) / self.num_microbatch + if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: @@ -472,19 +476,50 @@ def backward_b_step( # calculate bwd b step ; only dx = w*dy; # Retain the grad on the input_obj. 
- tree_map(retain_grad, input_obj) + if input_obj is None: + return None + else: + tree_map(retain_grad, input_obj) if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - optimizer.backward_by_grad( - tensor=output_obj, - grad=output_obj_grad, - inputs=input_obj, - retain_graph=True, - ) - return input_obj.grad + print(f"{input_obj=}") + print(f"{output_obj=}") + + if "hidden_states" in input_obj.keys(): + input_obj_ = input_obj["hidden_states"] + else: + input_obj_ = input_obj["input_ids"] + + if output_obj_grad is None: + optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) + # optimizer.backward_by_grad( + # tensor=output_obj, + # grad=output_obj_grad, + # inputs=input_obj, + # retain_graph=True, + # ) + else: + output_obj_ = output_obj["hidden_states"] + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad, + inputs=input_obj_, + retain_graph=True, + ) + + # if "backward_tensor_keys" not in output_obj: + # for k, grad in output_obj_grad.items(): + # optimizer.backward_by_grad(output_obj[k], grad, inputs=input_obj_, retain_graph=True) + # else: + # for k, grad in output_obj_grad.items(): + # output_obj[k].grad = grad + # for k in output_obj["backward_tensor_keys"]: + # tensor_to_backward = output_obj[k] + # optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad, inputs=input_obj_, retain_graph=True) + return input_obj_.grad def backward_w_step( self, @@ -511,12 +546,20 @@ def backward_w_step( if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss output_obj_grad = None - optimizer.backward_by_grad( - tensor=output_obj, - grad=output_obj_grad, - inputs=list(model_chunk[model_chunk_id].parameters()), - retain_graph=False, - ) + + if output_obj_grad is None: + print(optimizer) + # optimizer.backward(output_obj, inputs=list(model_chunk[model_chunk_id].parameters()), retain_graph=True) + optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=True) + else: + output_obj_ = output_obj["hidden_states"] + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad, + # inputs=list(model_chunk[model_chunk_id].parameters()), + inputs=list(model_chunk.parameters()), + retain_graph=False, + ) def schedule_f( self, @@ -540,12 +583,11 @@ def schedule_f( Returns: Nothing. 
""" - micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # Step1: recv fwd if model_chunk_id == 0: # is first stage; get input from func param if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = micro_batch + input_obj = None else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) else: @@ -557,7 +599,9 @@ def schedule_f( input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) # Here, let input_obj.requires_grad_() - tree_map(torch.Tensor.requires_grad_, input_obj) + # add input and output object for backward b + if input_obj is not None: + tree_map(torch.Tensor.requires_grad_, input_obj) # Step2: fwd step output_obj = self.forward_step( @@ -572,7 +616,8 @@ def schedule_f( # We should not detach bwd LOSS pass else: - detached_output_obj = output_obj.clone().detach() + # detached_output_obj = output_obj.clone().detach() + detached_output_obj = tree_map(detach, output_obj) # Step3: send fwd # add output to send_fwd_buffer @@ -589,10 +634,10 @@ def schedule_f( else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) - # add input and output object for backward b - self.input_tensors[model_chunk_id].append(input_obj) # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) + # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) + # deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) + self.input_tensors[model_chunk_id].append(input_obj) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) @@ -718,7 +763,7 @@ def run_forward_only( accum_loss = None # reset accum loss at fwd end; - if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): + if return_loss and self.stage_manager.is_last_stage(): accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 354f110f0b0d..47bc73db69b8 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -26,6 +26,7 @@ def __init__( pg_mesh: ProcessGroupMesh, pipeline_axis: int, enable_interleave: bool = False, + use_zbv: bool = False, num_model_chunks: int = 1, num_layers_per_stage: Optional[List[int]] = None, ) -> None: @@ -49,6 +50,7 @@ def __init__( next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") self.is_interleave = enable_interleave + self.use_zbv = use_zbv # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers self.num_model_chunks: int = num_model_chunks # for shardformer, hold stage indices of model @@ -85,6 +87,15 @@ def get_stage_index( num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) stage_indices = [] + if self.use_zbv: + stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]]) + stage_indices.append( + [ + num_layers_per_stage_accumulated[2 * num_stages - stage - 1], + num_layers_per_stage_accumulated[2 * num_stages - stage], + ] + ) + return stage_indices for model_chunk in range(num_model_chunks): start_idx = 
num_layers_per_stage_accumulated[stage + model_chunk * num_stages] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] @@ -124,7 +135,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool: if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + # use zero bubble pipeline + if self.use_zbv: + return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1 + else: + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index af610500a8eb..09876ec56f8b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -225,6 +225,7 @@ def llama_model_forward( all_self_attns += (layer_outputs[1],) if stage_manager.is_last_stage(): + print(f"{hidden_states=}") hidden_states = self.norm(hidden_states) if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 67e6e92d1d36..074734f761cf 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.norm) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.norm) else: @@ -351,8 +353,10 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) + elif stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -404,7 +408,9 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.score) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.score) return held_layers diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 80b2c7961e29..d2754cbd965b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor): loss.backward() self._post_backward() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False): raise RuntimeError("Gemini is not compatible with pipeline. 
backward_by_grad shouldn't be called in Gemini.")

     @staticmethod
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index fdf2a497626f..2283afd1eb42 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -298,12 +298,12 @@ def backward(self, loss: torch.Tensor):
         loss = self.mix_precision_mixin.pre_backward(loss)
         self.module.backward(loss)

-    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
+    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False):
         # This function is called except the last stage of pipeline parallel
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
         self.module.backward_by_grad(tensor, grad)

     def _maybe_move_fp32_params(self):
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 51d7d1eaaa33..2587e26944f8 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -408,7 +408,7 @@ def _add_to_bucket(self, param, group_id):
     # torch.optim.Optimizer methods
     ################################

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and no_sync are not compatible"
@@ -416,7 +416,7 @@
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)

-        loss.backward(retain_graph=retain_graph)
+        loss.backward(inputs=inputs, retain_graph=retain_graph)

         if not self.require_grad_sync:
             return
@@ -427,14 +427,14 @@
         if self._overlap_communication:
             get_accelerator().synchronize()

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
         assert not (
             self._partition_grads and not self.require_grad_sync
         ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

         if self.mixed_precision_mixin is not None:
             grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
-        torch.autograd.backward(tensor, grad)
+        torch.autograd.backward(tensor, grad, inputs=inputs, retain_graph=retain_graph,)

         if not self.require_grad_sync:
             return
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 9ad84341ac9e..2c7a3866f784 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin(
     sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
     criterion = loss_fn
-
     plugin = pluggin_cls(**test_config)
     booster = Booster(plugin=plugin)
@@ -311,8 +310,16 @@ def check_output_hidden_state(
 ):
     org_hidden_state = org_output.last_hidden_state

-    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
-        sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
+    if stage_manager:
+        if stage_manager.use_zbv:
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                sharded_hidden_state =
sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state + elif stage_manager.is_last_stage(ignore_chunk=True): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state else: sharded_hidden_state = sharded_output.last_hidden_state diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3c66f609787a..e4b28f7f9929 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -22,6 +22,7 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) +from colossalai.pipeline.schedule.v_schedule import PipelineGraph os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" @@ -31,6 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config ) + print(f"{sharded_model=}") if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable() @@ -112,12 +114,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): + check_flag = False + if stage_manager is None: + check_flag = True + else: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + check_flag = True + elif stage_manager.is_last_stage(ignore_chunk=True): + check_flag = True + if check_flag: if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state( org_output, @@ -282,10 +292,39 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "zbv", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 0, + "initial_scale": 1, + "enable_gradient_checkpointing": False, + "parallel_output": False, + }, ], ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + if test_config.get("pp_style", None) == "zbv": + mem_f = 34 * 32 + 5 * 4 * 16 + mem_w = - 32 * 32 + mem_b = - mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage = test_config["pp_size"], + n_micro=test_config["num_microbatches"], + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue @@ -372,11 +411,11 @@ def test_llama(): spawn(check_llama, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama_3d(): - spawn(check_llama_3d, 8) +# @pytest.mark.largedist +# @rerun_if_address_is_in_use() +# @clear_cache_before_run() +# def test_llama_3d(): +# spawn(check_llama_3d, 8) if __name__ == "__main__": From e6da1aaa7b00749f995c122d2e0af95816662e4c Mon Sep 17 00:00:00 2001 From: flybird11111 
<1829166702@qq.com> Date: Thu, 12 Sep 2024 03:55:12 +0000 Subject: [PATCH 02/47] fix fix --- .../pipeline/schedule/zero_bubble_pp.py | 33 ++----------------- colossalai/shardformer/modeling/llama.py | 1 - .../test_model/test_shard_llama.py | 10 +++--- 3 files changed, 8 insertions(+), 36 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 03196c48c311..9cdb3f5d1b1f 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -33,11 +33,9 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): """ if (out is None) or (not deallocate_pipeline_outputs): return - print(f"{out=}") assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) - # out.data.untyped_storage().resize_(0) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -439,7 +437,6 @@ def forward_step( internal_inputs = {} if input_obj is None else input_obj internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - # output_obj = model_chunk[model_chunk_id](input_obj) # last layer in model if self.stage_manager.is_last_stage(): loss = criterion(output_obj, micro_batch) / self.num_microbatch @@ -484,9 +481,6 @@ def backward_b_step( if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - - print(f"{input_obj=}") - print(f"{output_obj=}") if "hidden_states" in input_obj.keys(): input_obj_ = input_obj["hidden_states"] @@ -495,12 +489,6 @@ def backward_b_step( if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) - # optimizer.backward_by_grad( - # tensor=output_obj, - # grad=output_obj_grad, - # inputs=input_obj, - # retain_graph=True, - # ) else: output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( @@ -509,16 +497,6 @@ def backward_b_step( inputs=input_obj_, retain_graph=True, ) - - # if "backward_tensor_keys" not in output_obj: - # for k, grad in output_obj_grad.items(): - # optimizer.backward_by_grad(output_obj[k], grad, inputs=input_obj_, retain_graph=True) - # else: - # for k, grad in output_obj_grad.items(): - # output_obj[k].grad = grad - # for k in output_obj["backward_tensor_keys"]: - # tensor_to_backward = output_obj[k] - # optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad, inputs=input_obj_, retain_graph=True) return input_obj_.grad def backward_w_step( @@ -548,15 +526,12 @@ def backward_w_step( output_obj_grad = None if output_obj_grad is None: - print(optimizer) - # optimizer.backward(output_obj, inputs=list(model_chunk[model_chunk_id].parameters()), retain_graph=True) optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=True) else: output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad, - # inputs=list(model_chunk[model_chunk_id].parameters()), inputs=list(model_chunk.parameters()), retain_graph=False, ) @@ -616,7 +591,6 @@ def schedule_f( # We should not detach bwd LOSS pass else: - # detached_output_obj = output_obj.clone().detach() detached_output_obj = tree_map(detach, output_obj) # Step3: send fwd @@ -634,10 
+608,10 @@ def schedule_f( else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + # add input and output object for backward b + self.input_tensors[model_chunk_id].append(input_obj) # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) - # deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) - self.input_tensors[model_chunk_id].append(input_obj) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) @@ -690,7 +664,6 @@ def schedule_b( # we save output_tensor_grad here self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) - # _wait_p2p(recv_bwd_handles) # Step2: bwd step input_object_grad = self.backward_b_step( model_chunk=model_chunk, @@ -763,7 +736,7 @@ def run_forward_only( accum_loss = None # reset accum loss at fwd end; - if return_loss and self.stage_manager.is_last_stage(): + if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 09876ec56f8b..af610500a8eb 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -225,7 +225,6 @@ def llama_model_forward( all_self_attns += (layer_outputs[1],) if stage_manager.is_last_stage(): - print(f"{hidden_states=}") hidden_states = self.norm(hidden_states) if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index e4b28f7f9929..6fe7e0d7a920 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -411,11 +411,11 @@ def test_llama(): spawn(check_llama, 4) -# @pytest.mark.largedist -# @rerun_if_address_is_in_use() -# @clear_cache_before_run() -# def test_llama_3d(): -# spawn(check_llama_3d, 8) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) if __name__ == "__main__": From 6d5b32b80b0d1852e31851dc7c315adb1c2d906a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 04:00:18 +0000 Subject: [PATCH 03/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../booster/plugin/hybrid_parallel_plugin.py | 15 +++++++---- .../pipeline/schedule/zero_bubble_pp.py | 10 ++++--- colossalai/shardformer/policies/llama.py | 4 +-- colossalai/zero/gemini/gemini_optimizer.py | 4 ++- colossalai/zero/low_level/low_level_optim.py | 7 ++++- .../test_model/test_shard_llama.py | 26 +++++++++---------- 6 files changed, 41 insertions(+), 25 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1323d14b320c..82ad5bca84fd 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -39,7 +39,6 @@ from 
colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero.low_level import LowLevelZeroOptimizer from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle -from colossalai.pipeline.schedule.v_schedule import PipelineGraph from .pp_plugin_base import PipelinePluginBase @@ -1031,7 +1030,9 @@ def __init__( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" - assert not pp_style == "zbv" or scheduler_nodes is not None, f"scheduler_nodes must not be None when using zero bubble pipeline." + assert ( + not pp_style == "zbv" or scheduler_nodes is not None + ), f"scheduler_nodes must not be None when using zero bubble pipeline." if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" @@ -1096,8 +1097,12 @@ def __init__( assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" - assert pp_style in ["interleaved", "zbv"] or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" - assert pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2, "num_model_chunks must be 2 when using zero bubble pipeline" + assert ( + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -1133,7 +1138,7 @@ def __init__( elif pp_style == "zbv": self.scheduler = ZeroBubbleVPipeScheduler( stage_manager=self.stage_manager, - schedule = scheduler_nodes, + schedule=scheduler_nodes, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, microbatch_size=microbatch_size, diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 9cdb3f5d1b1f..cc4f700bac6a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -12,7 +12,7 @@ from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device, model_forward +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} @@ -35,7 +35,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): return assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." 
- out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) + out.data = torch.empty( + (1,), + device=out.device, + dtype=out.dtype, + ) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -481,7 +485,7 @@ def backward_b_step( if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - + if "hidden_states" in input_obj.keys(): input_obj_ = input_obj["hidden_states"] else: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 074734f761cf..60da448d8767 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -262,7 +262,7 @@ def get_held_layers(self) -> List[Module]: for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): - held_layers.append(module.norm) + held_layers.append(module.norm) elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.norm) @@ -356,7 +356,7 @@ def get_held_layers(self) -> List[Module]: if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) elif stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 2283afd1eb42..ccd4634b5fe2 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -298,7 +298,9 @@ def backward(self, loss: torch.Tensor): loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) - def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False): + def backward_by_grad( + self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False + ): # This function is called except the last stage of pipeline parallel # It receives the scaled grad from the previous rank # No need to scale the grad again diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 2587e26944f8..9cc44c7538dd 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -434,7 +434,12 @@ def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bo if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) - torch.autograd.backward(tensor, grad, inputs=inputs, retain_graph=retain_graph,) + torch.autograd.backward( + tensor, + grad, + inputs=inputs, + retain_graph=retain_graph, + ) if not self.require_grad_sync: return diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 6fe7e0d7a920..2a21262e9c3f 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -7,6 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import 
Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter @@ -22,7 +23,6 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) -from colossalai.pipeline.schedule.v_schedule import PipelineGraph os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" @@ -122,7 +122,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if stage_manager.is_first_stage(ignore_chunk=True): check_flag = True elif stage_manager.is_last_stage(ignore_chunk=True): - check_flag = True + check_flag = True if check_flag: if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 @@ -311,18 +311,18 @@ def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") if test_config.get("pp_style", None) == "zbv": mem_f = 34 * 32 + 5 * 4 * 16 - mem_w = - 32 * 32 - mem_b = - mem_w - mem_f + mem_w = -32 * 32 + mem_b = -mem_w - mem_f scheduler_nodes = PipelineGraph( - n_stage = test_config["pp_size"], - n_micro=test_config["num_microbatches"], - f_cost=1000, - b_cost=1000, - w_cost=1000, - c_cost=1, - f_mem=mem_f, - b_mem=mem_b, - w_mem=mem_w, + n_stage=test_config["pp_size"], + n_micro=test_config["num_microbatches"], + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, ).get_v_schedule() test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): From 4404775b9784d8bf08cc07ea74292cb63812a202 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 10:55:44 +0000 Subject: [PATCH 04/47] fix --- .../naive_amp/mixed_precision_optimizer.py | 4 ++-- .../booster/mixed_precision/fp16_torch.py | 4 ++-- .../booster/plugin/hybrid_parallel_plugin.py | 10 ++++++---- colossalai/interface/optimizer.py | 4 ++-- .../pipeline/schedule/zero_bubble_pp.py | 19 ++++++------------- .../test_model/test_shard_llama.py | 5 ++--- 6 files changed, 20 insertions(+), 26 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 700f80336cf0..121c92011cc5 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -85,9 +85,9 @@ def __init__( master_params.append(master_p) group["params"] = master_params - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): loss = self.mixed_precision.pre_backward(loss) - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index c757a878d97a..a85d9f808546 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -46,9 +46,9 @@ def __init__( growth_interval=growth_interval, ) - def backward(self, loss: Tensor, *args, **kwargs) -> None: + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None: scaled_loss = self.scale_loss(loss) - scaled_loss.backward(*args, **kwargs) + scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def step(self, *args, **kwargs) -> Optional[float]: out = self.scaler.step(self.optim, 
*args, **kwargs) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1323d14b320c..791281506bff 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -289,7 +289,7 @@ def __init__( self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 super().__init__(optim) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -307,7 +307,7 @@ def backward(self, loss: Tensor, *args, **kwargs): """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -513,7 +513,7 @@ def __init__( max_norm=max_norm, ) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -530,7 +530,7 @@ def backward(self, loss: Tensor, *args, **kwargs): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -1104,6 +1104,8 @@ def __init__( assert ( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + if pp_style == "zbv": + self.logger.warning("""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""") self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index a236434a55d6..c8cf3ec21360 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs): """ self.optim.zero_grad(*args, **kwargs) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): """ Performs a backward pass on the loss. 
""" - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 9cdb3f5d1b1f..6c30b32dad38 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -526,7 +526,7 @@ def backward_w_step( output_obj_grad = None if output_obj_grad is None: - optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=True) + optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( @@ -587,31 +587,27 @@ def schedule_f( accum_loss=accum_loss, outputs=outputs, ) - if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # We should not detach bwd LOSS - pass - else: - detached_output_obj = tree_map(detach, output_obj) # Step3: send fwd # add output to send_fwd_buffer if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): + detached_output_obj = tree_map(detach, output_obj) self.local_send_forward_buffer.append(detached_output_obj) + # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj + # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) else: - self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + self.send_forward_buffer[model_chunk_id].append(output_obj) else: # is first stage; end of fwd; append LOSS to local_send_backward_buffer if self.stage_manager.is_first_stage(ignore_chunk=True): pass else: - self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + self.send_forward_buffer[model_chunk_id].append(output_obj) # add input and output object for backward b self.input_tensors[model_chunk_id].append(input_obj) - # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) @@ -622,9 +618,6 @@ def schedule_b( model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, - # input_obj: Optional[dict], - # output_obj: Union[dict, torch.Tensor], - # output_obj_grad: Optional[dict], ): """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd; diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 6fe7e0d7a920..04b962a513db 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -32,10 +32,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config ) - print(f"{sharded_model=}") if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() - sharded_model.unwrap().gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False}) org_loss, org_output, sharded_loss, 
sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -302,7 +301,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "zero_stage": 0, "initial_scale": 1, - "enable_gradient_checkpointing": False, + "enable_gradient_checkpointing": True, "parallel_output": False, }, ], From 3feda3ba2ebd3f1df9db6c4dda15aabbdabd28f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:58:15 +0000 Subject: [PATCH 05/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 +++- tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3a78db6ac4fa..5d114ab9c315 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1110,7 +1110,9 @@ def __init__( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" if pp_style == "zbv": - self.logger.warning("""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""") + self.logger.warning( + """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""" + ) self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index b74365c7f2af..d925687cd875 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -34,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() - sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False}) + sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster From 6d0122d50a3e9945718bf1d4edae4f7ac2eff6cf Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 11:29:46 +0000 Subject: [PATCH 06/47] fix --- colossalai/pipeline/schedule/zero_bubble_pp.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 3b6d88a3806b..3700b8c95f7a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -35,11 +35,7 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): return assert isinstance(out, torch.Tensor), "expected Tensor, found %s." 
% type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." - out.data = torch.empty( - (1,), - device=out.device, - dtype=out.dtype, - ) + out.data.untyped_storage().resize_(0) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -336,6 +332,7 @@ def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: next_rank = self.stage_manager.get_next_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) + tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_tensor) return send_handles else: @@ -354,6 +351,7 @@ def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: prev_rank = self.stage_manager.get_prev_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_tensor, prev_rank) + tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_tensor) return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: @@ -597,10 +595,11 @@ def schedule_f( if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): - detached_output_obj = tree_map(detach, output_obj) + output_obj_clone = tree_map(torch.Tensor.clone, output_obj) + detached_output_obj = tree_map(detach, output_obj_clone) self.local_send_forward_buffer.append(detached_output_obj) # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) + tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) else: self.send_forward_buffer[model_chunk_id].append(output_obj) else: From 9802a7d890c1a6d89066826fd1043f5ee689b78a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 20:06:41 +0800 Subject: [PATCH 07/47] Update zero_bubble_pp.py --- colossalai/pipeline/schedule/zero_bubble_pp.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 3700b8c95f7a..759a6144c169 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -484,11 +484,7 @@ def backward_b_step( # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - if "hidden_states" in input_obj.keys(): - input_obj_ = input_obj["hidden_states"] - else: - input_obj_ = input_obj["input_ids"] - + input_obj_ = input_obj["hidden_states"] if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: From 9c59e6cae311c7f6cca33d2ddcb6c1b7ac0390d2 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Sep 2024 02:51:28 +0000 Subject: [PATCH 08/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 24 +++++++++++-------- .../test_model/test_shard_llama.py | 2 +- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 759a6144c169..5b4092bcefe3 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -11,6 +11,8 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication from 
colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.zero.low_level import LowLevelZeroOptimizer +from contextlib import nullcontext from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -485,16 +487,18 @@ def backward_b_step( assert output_obj_grad is None input_obj_ = input_obj["hidden_states"] - if output_obj_grad is None: - optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) - else: - output_obj_ = output_obj["hidden_states"] - optimizer.backward_by_grad( - tensor=output_obj_, - grad=output_obj_grad, - inputs=input_obj_, - retain_graph=True, - ) + ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext() + with ctx: + if output_obj_grad is None: + optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) + else: + output_obj_ = output_obj["hidden_states"] + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad, + inputs=input_obj_, + retain_graph=True, + ) return input_obj_.grad def backward_w_step( diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index d925687cd875..9f67ecbea687 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -299,7 +299,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 4, "enable_all_optimization": False, "precision": "fp16", - "zero_stage": 0, + "zero_stage": 1, "initial_scale": 1, "enable_gradient_checkpointing": True, "parallel_output": False, From 37d9623001f3d10e7feef42390092d0297f5be73 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Sep 2024 06:13:22 +0000 Subject: [PATCH 09/47] fix-ci --- .github/workflows/build_on_pr.yml | 2 +- .github/workflows/build_on_schedule.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 58cd8826809a..79d758c87976 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . pip install --no-cache-dir -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index fc688a71bd92..e7b5063279eb 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -55,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . 
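Stepping back to the backward_b_step change in PATCH 08: the B/W split there leans on torch.autograd's inputs= argument, which restricts a backward pass to a chosen set of leaves. A minimal sketch of the same split on a toy linear layer (shapes and names are illustrative, not ColossalAI code):

import torch

layer = torch.nn.Linear(4, 4)
x = torch.rand(2, 4, requires_grad=True)
loss = (layer(x) ** 2).mean()

# B step: gradient w.r.t. the stage input only; retain_graph keeps the graph
# alive so the weight gradients can still be computed afterwards.
loss.backward(inputs=[x], retain_graph=True)
assert x.grad is not None and layer.weight.grad is None

# W step: gradient w.r.t. the parameters, releasing the graph.
loss.backward(inputs=list(layer.parameters()))
assert layer.weight.grad is not None

Deferring the W pass is what lets the zero-bubble schedule fill pipeline bubbles with weight-gradient work; PATCH 06's out.data.untyped_storage().resize_(0) complements it by freeing an already-sent activation's storage while keeping its grad_fn reachable for those later passes.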
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install --no-cache-dir -r requirements/requirements-test.txt From 5965f8bfdea869235fa17f63c370b28a5e79c07c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Sep 2024 08:21:36 +0000 Subject: [PATCH 10/47] fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix --- colossalai/pipeline/schedule/zero_bubble_pp.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 5b4092bcefe3..7cff690d248e 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -11,9 +12,6 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.zero.low_level import LowLevelZeroOptimizer -from contextlib import nullcontext - from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -487,7 +485,13 @@ def backward_b_step( assert output_obj_grad is None input_obj_ = input_obj["hidden_states"] - ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext() + + # Attempt to disable gradient synchronization when using the LowLevelZeroPlugin. + try: + ctx = optimizer.no_sync() + except Exception as e: + ctx = nullcontext() + with ctx: if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) From f99fc6db5158323a75d1cc60c7ecc5847df8facf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 08:26:56 +0000 Subject: [PATCH 11/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/pipeline/schedule/zero_bubble_pp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 7cff690d248e..2525e8b16bf4 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -12,6 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager + from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -489,7 +490,7 @@ def backward_b_step( # Attempt to disable gradient synchronization when using the LowLevelZeroPlugin. 
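# The probe below is deliberately duck-typed: an optimizer that exposes a
# no_sync() context manager (LowLevelZeroOptimizer does) has gradient
# synchronization suspended for this partial backward, while one without it
# raises AttributeError and falls back to nullcontext, leaving the backward
# untouched.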
try: ctx = optimizer.no_sync() - except Exception as e: + except Exception: ctx = nullcontext() with ctx: From e169f979c1db28d1482a6c5adb4c2a046c67adfb Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:30:21 +0000 Subject: [PATCH 12/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 36 +- .../test_schedule/test_zerobubble_pp.py | 857 +++--------------- tests/test_shardformer/test_model/_utils.py | 2 +- .../test_model/test_shard_llama.py | 272 +++--- 4 files changed, 304 insertions(+), 863 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 7cff690d248e..097b1179693a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -157,8 +157,8 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.num_microbatch % self.stage_manager.num_stages == 0 ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" - if self.forward_only: - self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + # if self.forward_only: + # self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) # if self.batch_size != self.last_batch_size: # self.enable_metadata_cache = False @@ -435,15 +435,24 @@ def forward_step( micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) with self.stage_manager.switch_model_chunk_id(model_chunk_id): - # fwd calculate - internal_inputs = {} if input_obj is None else input_obj - internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] - output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + if isinstance(model_chunk, ModuleList): + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + else: + # fwd calculate + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + # last layer in model if self.stage_manager.is_last_stage(): + print(f"aaaaa{output_obj=}") loss = criterion(output_obj, micro_batch) / self.num_microbatch + print(f"bbbb{loss=}") if accum_loss is not None: - accum_loss.add_(loss.detach()) + print(f"accum_loss{accum_loss=}") + if not torch.isinf(loss): + accum_loss.add_(loss.detach()) + print(f"add accum_loss{accum_loss=}") if outputs is not None: outputs.append(tree_map(detach, output_obj)) return loss @@ -497,6 +506,7 @@ def backward_b_step( optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: output_obj_ = output_obj["hidden_states"] + print(f"{model_chunk_id=}, {output_obj_grad=}") optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad, @@ -531,6 +541,8 @@ def backward_w_step( # loss backward; output_obj is loss output_obj_grad = None + if isinstance(model_chunk, ModuleList): + model_chunk = model_chunk[model_chunk_id] if output_obj_grad is None: optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: @@ -742,9 +754,10 @@ def run_forward_only( outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None # while we still have schedules_node in self.schedules - for it in range(len(self.schedules)): - scheduled_node = self.schedules[it] - + # while we still have schedules_node in 
self.schedules + schedules = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) + for it in range(len(schedules)): + scheduled_node = schedules[it] if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: # communication communication_func = self.communication_map[scheduled_node.type] @@ -761,6 +774,7 @@ def run_forward_only( # return loss & output if outputs is not None: outputs = merge_batch(outputs) + print(f"{accum_loss=}") return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( @@ -857,6 +871,6 @@ def forward_backward_step( model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) - self.assert_buffer_empty() + # self.assert_buffer_empty() return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 825c192d8fd5..0bde70ec6c16 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -1,5 +1,6 @@ -from copy import deepcopy -from typing import Tuple +import copy +from functools import partial +from types import MethodType import pytest import torch @@ -10,18 +11,20 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper -from colossalai.logging import disable_existing_loggers from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + +NUM_LAYER = 8 +DIM = 4 class MlpModel(nn.Module): - def __init__(self, in_dim, out_dim, num_layers): + def __init__(self): super().__init__() - self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)]) def forward(self, x): for layer in self.layers: @@ -29,741 +32,163 @@ def forward(self, x): return x -def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: - num_params = 0 - num_params_trainable = 0 - for p in model.parameters(): - num_params += p.numel() - if p.requires_grad: - num_params_trainable += p.numel() - return num_params, num_params_trainable - - -# 1) Test manual v_schedule with multiple microbatch -@parameterize( - "test_config", - [ - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - ], -) -def run_fwd_bwd_iter_input(test_config): - # init dist - rank = dist.get_rank() - pp_size = test_config["pp_size"] - pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = test_config["num_microbatches"] - num_model_chunk = test_config["num_model_chunk"] - # stage_manager +def pp_linear_fwd( + forward, + data: torch.Tensor = None, + hidden_states: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + model_chunk_id: int = None, +): + with stage_mgr.switch_model_chunk_id(model_chunk_id): + if stage_mgr.is_first_stage(): + return {"hidden_states": forward(data)} + elif stage_mgr.is_last_stage(): + return forward(hidden_states) + else: + return {"hidden_states": forward(hidden_states)} + + +def run_pp( + rank: int, + world_size: int, + port: 
int, + num_microbatch: int, + batch_size: int, + num_model_chunk: int, +): + """ + This test is to examine the correctness of interleaved 1F1B, compared with torch. + Be aware it contains some hardcodes. + """ + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + # create model + seed_all(1453) + torch_model = MlpModel().cuda() + pp_model = copy.deepcopy(torch_model).cuda() + + pg_mesh = ProcessGroupMesh(world_size) stage_manager = PipelineStageManager( - pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + pg_mesh, pipeline_axis=0, enable_interleave=True, use_zbv=True, num_model_chunks=num_model_chunk ) # schedule list - zbv_schedule = [ - # stage 0 - [ - # microbatch 0 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0), - ScheduledNode(type="F", chunk=0, stage=0, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0), - ScheduledNode(type="F", chunk=1, stage=0, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0), - ScheduledNode(type="B", chunk=1, stage=0, minibatch=0), - ScheduledNode(type="W", chunk=1, stage=0, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0), - ScheduledNode(type="B", chunk=0, stage=0, minibatch=0), - ScheduledNode(type="W", chunk=0, stage=0, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), - # microbatch 1 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1), - ScheduledNode(type="F", chunk=0, stage=0, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1), - ScheduledNode(type="F", chunk=1, stage=0, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1), - ScheduledNode(type="B", chunk=1, stage=0, minibatch=1), - ScheduledNode(type="W", chunk=1, stage=0, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1), - ScheduledNode(type="B", chunk=0, stage=0, minibatch=1), - ScheduledNode(type="W", chunk=0, stage=0, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), - # microbatch 2 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2), - ScheduledNode(type="F", chunk=0, stage=0, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2), - ScheduledNode(type="F", chunk=1, stage=0, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2), - ScheduledNode(type="B", chunk=1, stage=0, minibatch=2), - ScheduledNode(type="W", chunk=1, stage=0, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=0, minibatch=2), 
- ScheduledNode(type="W", chunk=0, stage=0, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), - ], - # stage 1 - [ - # microbatch 0 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="F", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), - # microbatch 1 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="F", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), - # microbatch 2 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="F", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2), - # chunk 
1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), - ], - # stage 2 - [ - # microbatch 0 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), - # microbatch 1 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1), - # microbatch 2 - # chunk 0 fwd - 
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3), - ], - # stage 3 - [ - # microbatch 0 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="W", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0), - # microbatch 1 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="W", 
chunk=1, stage=3, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1), - # microbatch 2 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="W", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3), - ], - ] - - scheduler = ZeroBubbleVPipeScheduler( - schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? 
- stage_manager=stage_manager, - num_model_chunks=pp_size, - num_microbatch=num_microbatch, - overlap_p2p=False, - ) - - # loss func - def criterion(x, *args, **kwargs): - return (x * x).mean() - - # init model and input - batch_size = 4 - num_layers = 8 - in_dim = out_dim = 8 - print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") - model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) - data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - - input_base = [t.clone() for t in data_iter] - model_base = deepcopy(model) - - if rank == 0: - # layer 0 & 7 to chunk 0 on rank0 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 0 or idx == 7: - local_chunk.append(sub_model) - elif rank == 1: - # layer 1 & 6 to chunk 1 on rank1 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 1 or idx == 6: - local_chunk.append(sub_model) - elif rank == 2: - # layer 2 & 5 to chunk 2 on rank2 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 2 or idx == 5: - local_chunk.append(sub_model) - else: - # layer 3 & 4 to chunk 3 on rank3 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 3 or idx == 4: - local_chunk.append(sub_model) - # init optimizer - optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) - optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) - - print( - f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" - ) - - torch.cuda.synchronize() - result = scheduler.forward_backward_step( - model_chunk=local_chunk, - data_iter=iter(data_iter), - criterion=criterion, - optimizer=optimizer_pp, - return_loss=True, - return_outputs=True, - ) - - optimizer_pp.step() - - ########################## - # Fwd bwd for base - ########################## - # fwd & bwd - output_base = model_base(input_base[0]) - loss_base = criterion(output_base) - loss_base.backward() - optimizer_base.step() - print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - ########################## - # assert weight - ########################## - if rank == 0: - # layer 0 - assert_close(local_chunk[0].weight, model_base.layers[0].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) - # layer 7 - assert_close(local_chunk[1].weight, model_base.layers[7].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) - if rank == 1: - # layer 1 - assert_close(local_chunk[0].weight, model_base.layers[1].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) - # layer 6 - assert_close(local_chunk[1].weight, model_base.layers[6].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) - if rank == 2: - # layer 2 - assert_close(local_chunk[0].weight, model_base.layers[2].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) - # layer 5 - assert_close(local_chunk[1].weight, model_base.layers[5].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) - if rank == 3: - # layer 3 - assert_close(local_chunk[0].weight, model_base.layers[3].weight) - assert_close(local_chunk[0].weight.grad, 
model_base.layers[3].weight.grad) - # layer 4 - assert_close(local_chunk[1].weight, model_base.layers[4].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) - - -# 2) add optimizer base 1) -@parameterize( - "test_config", - [ - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 8, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - ], -) -def run_fwd_bwd_vschedule_with_optim(test_config): - # init dist - rank = dist.get_rank() - pp_size = test_config["pp_size"] - pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = test_config["num_microbatches"] - num_model_chunk = test_config["num_model_chunk"] - # stage_manager - stage_manager = PipelineStageManager( - pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk - ) - - h, a, s = 4096, 32, 1024 - mem_f = 34 * h + 5 * a * s - mem_w = -32 * h + mem_f = 34 * 32 + 5 * 4 * 16 + mem_w = -32 * 32 mem_b = -mem_w - mem_f - graph = PipelineGraph( - n_stage=pp_size, - n_micro=num_microbatch, - f_cost=1, - b_cost=1, - w_cost=1, + scheduler_nodes = PipelineGraph( + n_stage=4, + n_micro=12, + f_cost=1000, + b_cost=1000, + w_cost=1000, c_cost=1, f_mem=mem_f, b_mem=mem_b, w_mem=mem_w, - # max_mem=mem_f * (p * 2 + m_offset), - ) - - zbv_schedule = graph.get_v_schedule() - - scheduler = ZeroBubbleVPipeScheduler( - schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? + ).get_v_schedule() + schedule = ZeroBubbleVPipeScheduler( stage_manager=stage_manager, + schedule=scheduler_nodes, num_model_chunks=num_model_chunk, num_microbatch=num_microbatch, - overlap_p2p=False, ) - # init loss func + sharded_model = torch.nn.ModuleList() + for idx, sub_model in enumerate(pp_model.layers): + if idx == rank or (NUM_LAYER-idx-1) == rank: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)), + sub_model._forward, + ) + sharded_model.append(sub_model.cuda()) + assert len(sharded_model) == num_model_chunk, f"{len(sharded_model)}, {num_model_chunk}, num_model_chunk is not correct" + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5)) + + # create data + seed_all(115) + input_list = [torch.rand(batch_size, DIM).cuda()] + dist.all_reduce(input_list[0]) + def criterion(x, *args, **kwargs): return (x * x).mean() - # init model and input - batch_size = test_config["batch_size"] - num_layers = 8 - assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" - in_dim = out_dim = 4096 - before_init_memory = torch.cuda.memory_allocated() / 1024**3 - print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") - model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) - data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - - input_base = [t.clone() for t in data_iter] - model_base = deepcopy(model) - - if rank == 0: - # layer 0 & 7 to chunk 0 on rank0 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 0 or idx == 7: - local_chunk.append(sub_model) - elif rank == 1: - # layer 1 & 6 to 
chunk 1 on rank1 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 1 or idx == 6: - local_chunk.append(sub_model) - elif rank == 2: - # layer 2 & 5 to chunk 2 on rank2 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 2 or idx == 5: - local_chunk.append(sub_model) - else: - # layer 3 & 4 to chunk 3 on rank3 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 3 or idx == 4: - local_chunk.append(sub_model) - - # init optimizer - optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) - optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5)) - - after_init_memory = torch.cuda.memory_allocated() / 1024**3 - print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") - - torch.cuda.synchronize() - result = scheduler.forward_backward_step( - model_chunk=local_chunk, - data_iter=iter(data_iter), - criterion=criterion, - optimizer=optimizer_pp, - return_loss=True, - return_outputs=True, - ) - - optimizer_pp.step() - - after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 - - # assert memory - if rank != 0: - # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 - # output hid_dim * hid_dim * 4(fp32) / 1024**3 - # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 - print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}") - assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3) - else: - # rank0 will also hold output; - print( - f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}" + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) + + # check loss + if stage_manager.is_first_stage(ignore_chunk=True): + assert_close(torch_loss, pp_ret["loss"]) + + # check gradients + for i in range(num_model_chunk): + # idx = world_size * i + rank + if i == 0: + idx = rank + else: + idx = world_size * 2 - rank - 1 + print(f"{i=}, {idx=}, {rank=}, {torch_model.layers[idx].weight.grad=}, {sharded_model[i].weight.grad=}") + assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) + + # step + torch_optimizer.step() + pp_optimizer.step() + pp_optimizer.zero_grad() + + # check updated param + for i in range(num_model_chunk): + # idx = world_size * i + rank + if i == 0: + idx = rank + else: + idx = world_size * 2 - rank - 1 + assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) + assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) + + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) - assert round((after_pp_step_memory - after_init_memory), 5) <= round( - (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 - ) - - ########################## - 
# Fwd bwd for base - ########################## - # fwd & bwd - output_base = model_base(input_base[0]) - loss_base = criterion(output_base) - loss_base.backward() - optimizer_base.step() - - ########################## - # assert loss & output - ########################## - # only chunk 1 stage 0 hold loss and output - if rank == 0: - assert_close(result["loss"], loss_base) - assert_close(result["outputs"], output_base) + if stage_manager.is_first_stage(ignore_chunk=True): + print(f"{torch_loss=}, {pp_ret['loss']}") + assert_close(torch_loss, pp_ret["loss"]) - # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ") - ########################## - # assert weight - ########################## - if rank == 0: - # layer 0 - assert_close(local_chunk[0].weight, model_base.layers[0].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) - # layer 7 - assert_close(local_chunk[1].weight, model_base.layers[7].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) - if rank == 1: - # layer 1 - assert_close(local_chunk[0].weight, model_base.layers[1].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) - # layer 6 - assert_close(local_chunk[1].weight, model_base.layers[6].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) - if rank == 2: - # layer 2 - assert_close(local_chunk[0].weight, model_base.layers[2].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) - # layer 5 - assert_close(local_chunk[1].weight, model_base.layers[5].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) - if rank == 3: - # layer 3 - assert_close(local_chunk[0].weight, model_base.layers[3].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) - # layer 4 - assert_close(local_chunk[1].weight, model_base.layers[4].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) - - ########################## - # assert optim state - ########################## - optim_base_state = optimizer_base.state_dict()["state"] - optim_pp_state = optimizer_pp.state_dict()["state"] - optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0] - optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0] - # if rank == 0: - # print(f"optim_base_state {optim_base_state}") - - # assert param group - for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): - if key_base == key_pp: - if key_base != "params": - assert val_base == val_pp + for layer in sharded_model: + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None else: - # BUG: - # param_base: [0, 1, 2, 3, 4, 5, 6, 7]; - # params pp: [0, 1]; - assert val_base[:2] == val_pp - - # assert state - assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"]) - assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"]) - - -# TODO:4) support Hybrid base 3) -def run_with_hybridplugin(test_config): - pass - - -# TODO:5) support MoEHybrid base 3) -@parameterize( - "test_config", - [ - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - ], -) -def run_with_moehybridplugin(test_config): - model_zoo.get_sub_registry("transformers_bert") - 
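The index arithmetic in the new run_pp checks (idx = rank for chunk 0, idx = world_size * 2 - rank - 1 for chunk 1) encodes the V-shaped placement that gives zbv its name: each rank owns one layer from the model's front half plus its mirror from the back half. A standalone sanity check of that mapping, assuming the test's NUM_LAYER = 8 over 4 ranks:

NUM_LAYER, WORLD_SIZE = 8, 4

def owned_layers(rank: int) -> list:
    # chunk 0 walks forward through the stages; chunk 1 walks back (the "V")
    return [rank, WORLD_SIZE * 2 - rank - 1]

assert [owned_layers(r) for r in range(WORLD_SIZE)] == [[0, 7], [1, 6], [2, 5], [3, 4]]

This is the same per-rank split the deleted manual test built by hand (rank 0: layers 0 and 7, rank 1: layers 1 and 6, and so on).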
test_config["use_lazy_init"] = False - test_config["initial_scale"] = 2**16 - model_list = [ - "transformers_bert", - ] - - -# TODO:6) support booster & Hybrid base 4) - -# TODO:7) support booster & MoEHybrid base 4) - - -def run_dist(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # run_fwd_bwd_iter_input() - run_fwd_bwd_vschedule_with_optim() - # run_with_moehybridplugin() + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) @pytest.mark.dist +@pytest.mark.parametrize("num_microbatch", [12]) +@pytest.mark.parametrize("batch_size", [24]) +@pytest.mark.parametrize("num_model_chunk", [2]) @rerun_if_address_is_in_use() -def test_pp(): +def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): + assert NUM_LAYER % num_model_chunk == 0 spawn( - run_dist, - nprocs=4, + run_pp, + nprocs=NUM_LAYER // num_model_chunk, + num_microbatch=num_microbatch, + batch_size=batch_size, + num_model_chunk=num_model_chunk, ) if __name__ == "__main__": - test_pp() + test_pp(num_microbatch=4, batch_size=12, num_model_chunk=2) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 2c7a3866f784..dbefe872cbc8 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -397,7 +397,7 @@ def get_grad_tensors_for_check( pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - + print(f"grad_to_check {shard_grad=}") grad_to_check[suffix] = { "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 9f67ecbea687..81892817cec4 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -136,6 +136,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, rtol=rtol, shard_config=booster.plugin.shard_config, ) + print(f"check_loss") check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): @@ -143,6 +144,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 + print(f"check_weight") check_weight( llama_model, shard_llama_model, @@ -162,135 +164,135 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # Double Ring Attention - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 4, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - "inner_ring_size": 2, - }, - # Ring Attention + PP - { - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - # Ring Attention + TP - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - 
"initial_scale": 1, - }, - { # Ulysess + TP - "tp_size": 2, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { # Ulysess + PP - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "use_lazy_init": False, - "precision": "fp32", - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, + # # Double Ring Attention + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 4, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # "inner_ring_size": 2, + # }, + # # Ring Attention + PP + # { + # "tp_size": 1, + # "pp_size": 2, + # "sp_size": 2, + # "num_microbatches": 2, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # # Ring Attention + TP + # { + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { # Ulysess + TP + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { # Ulysess + PP + # "tp_size": 1, + # "pp_size": 2, + # "sp_size": 2, + # "num_microbatches": 2, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": 
"all_to_all", + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "split_gather", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 4, + # "use_lazy_init": False, + # "precision": "fp32", + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, { "tp_size": 2, "pp_size": 2, @@ -301,7 +303,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "zero_stage": 1, "initial_scale": 1, - "enable_gradient_checkpointing": True, + "enable_gradient_checkpointing": False, "parallel_output": False, }, ], @@ -410,11 +412,11 @@ def test_llama(): spawn(check_llama, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama_3d(): - spawn(check_llama_3d, 8) +# @pytest.mark.largedist +# @rerun_if_address_is_in_use() +# @clear_cache_before_run() +# def test_llama_3d(): +# spawn(check_llama_3d, 8) if __name__ == "__main__": From 7f78272878eca31fb66739cc734233632e1f5db6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 02:32:55 +0000 Subject: [PATCH 13/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/pipeline/schedule/zero_bubble_pp.py | 16 ++++++++-------- .../test_schedule/test_zerobubble_pp.py | 8 +++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index c76656200a2f..559874e210b2 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -160,13 +160,13 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) # if self.forward_only: # self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 - # NOTE: disable metadata cache when batch size changes (not valid anymore) - # if self.batch_size != self.last_batch_size: - # self.enable_metadata_cache = False - # self.send_tensor_metadata = True 
- # self.send_grad_metadata = True - # self.tensor_metadata_recv = None - # self.grad_metadata_recv = None + # NOTE: disable metadata cache when batch size changes (not valid anymore) + # if self.batch_size != self.last_batch_size: + # self.enable_metadata_cache = False + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None self.last_batch_size = self.batch_size @@ -443,7 +443,7 @@ def forward_step( internal_inputs = {} if input_obj is None else input_obj internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - + # last layer in model if self.stage_manager.is_last_stage(): print(f"aaaaa{output_obj=}") diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 0bde70ec6c16..31531069386e 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -11,7 +11,7 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper -from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -96,14 +96,16 @@ def run_pp( sharded_model = torch.nn.ModuleList() for idx, sub_model in enumerate(pp_model.layers): - if idx == rank or (NUM_LAYER-idx-1) == rank: + if idx == rank or (NUM_LAYER - idx - 1) == rank: sub_model._forward = sub_model.forward sub_model.forward = MethodType( partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)), sub_model._forward, ) sharded_model.append(sub_model.cuda()) - assert len(sharded_model) == num_model_chunk, f"{len(sharded_model)}, {num_model_chunk}, num_model_chunk is not correct" + assert ( + len(sharded_model) == num_model_chunk + ), f"{len(sharded_model)}, {num_model_chunk}, num_model_chunk is not correct" # create optimizer torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) From 885ace714ccd04fddc25e178fba2ee8d1e4bb27b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:42:14 +0000 Subject: [PATCH 14/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 6 - tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_shard_llama.py | 272 +++++++++--------- 3 files changed, 135 insertions(+), 144 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index c76656200a2f..201298cd65c2 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -446,14 +446,10 @@ def forward_step( # last layer in model if self.stage_manager.is_last_stage(): - print(f"aaaaa{output_obj=}") loss = criterion(output_obj, micro_batch) / self.num_microbatch - print(f"bbbb{loss=}") if accum_loss is not None: - print(f"accum_loss{accum_loss=}") if not torch.isinf(loss): accum_loss.add_(loss.detach()) - print(f"add accum_loss{accum_loss=}") if outputs is not None: outputs.append(tree_map(detach, output_obj)) return loss @@ -507,7 +503,6 @@ def backward_b_step( optimizer.backward(output_obj, inputs=input_obj_, 
retain_graph=True) else: output_obj_ = output_obj["hidden_states"] - print(f"{model_chunk_id=}, {output_obj_grad=}") optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad, @@ -775,7 +770,6 @@ def run_forward_only( # return loss & output if outputs is not None: outputs = merge_batch(outputs) - print(f"{accum_loss=}") return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index dbefe872cbc8..5c141e8f5cf1 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -397,7 +397,6 @@ def get_grad_tensors_for_check( pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - print(f"grad_to_check {shard_grad=}") grad_to_check[suffix] = { "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 81892817cec4..9f67ecbea687 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -136,7 +136,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, rtol=rtol, shard_config=booster.plugin.shard_config, ) - print(f"check_loss") check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): @@ -144,7 +143,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - print(f"check_weight") check_weight( llama_model, shard_llama_model, @@ -164,135 +162,135 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # # Double Ring Attention - # { - # "tp_size": 1, - # "pp_size": 1, - # "sp_size": 4, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # "inner_ring_size": 2, - # }, - # # Ring Attention + PP - # { - # "tp_size": 1, - # "pp_size": 2, - # "sp_size": 2, - # "num_microbatches": 2, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # # Ring Attention + TP - # { - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { # Ulysess + TP - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { # Ulysess + PP - # "tp_size": 1, - # "pp_size": 2, - # "sp_size": 2, - # "num_microbatches": 2, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # 
"num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "split_gather", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 4, - # "use_lazy_init": False, - # "precision": "fp32", - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, + # Double Ring Attention + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 4, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + "inner_ring_size": 2, + }, + # Ring Attention + PP + { + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + # Ring Attention + TP + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + TP + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + PP + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { 
+ "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -303,7 +301,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "zero_stage": 1, "initial_scale": 1, - "enable_gradient_checkpointing": False, + "enable_gradient_checkpointing": True, "parallel_output": False, }, ], @@ -412,11 +410,11 @@ def test_llama(): spawn(check_llama, 4) -# @pytest.mark.largedist -# @rerun_if_address_is_in_use() -# @clear_cache_before_run() -# def test_llama_3d(): -# spawn(check_llama_3d, 8) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) if __name__ == "__main__": From 629c76df00a4f52fa5becae5f82f7e8ce5e6b39e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:47:33 +0000 Subject: [PATCH 15/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 2 +- .../test_schedule/test_zerobubble_pp.py | 28 +++++++------------ 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 20ea46b77adc..982c95c6a727 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -866,6 +866,6 @@ def forward_backward_step( model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) - # self.assert_buffer_empty() + self.assert_buffer_empty() return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 31531069386e..6881893f157a 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -156,24 +156,16 @@ def criterion(x, *args, **kwargs): assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) - # forward only - with torch.no_grad(): - torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output) - - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True - ) - if stage_manager.is_first_stage(ignore_chunk=True): - print(f"{torch_loss=}, {pp_ret['loss']}") - assert_close(torch_loss, pp_ret["loss"]) - - for layer in sharded_model: - if layer.weight.grad is None: - assert layer.weight.grad is None and layer.bias.grad is None - else: - assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + # forward one step + 
torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True + ) + if stage_manager.is_first_stage(ignore_chunk=True): + print(f"{torch_loss=}, {pp_ret['loss']}") + assert_close(torch_loss, pp_ret["loss"]) @pytest.mark.dist From 1fcc3a612b42567a374b6e95608fc7e74f8c82d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 02:48:50 +0000 Subject: [PATCH 16/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 6881893f157a..9e7128af4b08 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -160,9 +160,7 @@ def criterion(x, *args, **kwargs): torch_output = torch_model(input_list[0]) torch_loss = criterion(torch_output) - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True - ) + pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) if stage_manager.is_first_stage(ignore_chunk=True): print(f"{torch_loss=}, {pp_ret['loss']}") assert_close(torch_loss, pp_ret["loss"]) From 88b068f16fae694454e8f90312edeb0f98827c93 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:52:22 +0000 Subject: [PATCH 17/47] fix --- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 6881893f157a..ce1e869c892c 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -132,7 +132,6 @@ def criterion(x, *args, **kwargs): # check gradients for i in range(num_model_chunk): - # idx = world_size * i + rank if i == 0: idx = rank else: @@ -148,7 +147,6 @@ def criterion(x, *args, **kwargs): # check updated param for i in range(num_model_chunk): - # idx = world_size * i + rank if i == 0: idx = rank else: @@ -160,9 +158,7 @@ def criterion(x, *args, **kwargs): torch_output = torch_model(input_list[0]) torch_loss = criterion(torch_output) - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True - ) + pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) if stage_manager.is_first_stage(ignore_chunk=True): print(f"{torch_loss=}, {pp_ret['loss']}") assert_close(torch_loss, pp_ret["loss"]) From 050971275f5b55c04f3b45d0aaa24be91dc8fe08 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 04:15:12 +0000 Subject: [PATCH 18/47] fix --- colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 6 +++--- .../test_pipeline_utils/test_t5_pipeline_utils.py | 1 + .../test_pipeline_utils/test_whisper_pipeline_utils.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py 
b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 36973b240896..6fbb72ba3e5b 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -278,7 +278,7 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: @@ -300,7 +300,7 @@ def __init__( if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -309,7 +309,7 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index e2f71ff89221..f79bdeb3a96f 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -15,6 +15,7 @@ def __init__(self): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index d39c5ea91dd4..722b8fd7cfae 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -15,6 +15,7 @@ def __init__(self): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): From 28f581f1e29999516cdd851e2e3a5ae73bf80baf Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 10:22:16 +0000 Subject: [PATCH 19/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 59 ++++++++++--------- .../test_schedule/test_zerobubble_pp.py | 27 ++++++--- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 982c95c6a727..8ee41ae753c8 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -5,7 +5,7 @@ import torch import torch.cuda from torch.nn import Module, ModuleList -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_flatten, tree_map from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper @@ -109,22 +109,15 @@ def _free_buffers(self): def assert_buffer_empty(self): # assert buuffer is empty at end - assert len(self.input_tensors[0]) == 0 - assert len(self.input_tensors[1]) == 0 - assert len(self.output_tensors[0]) == 0 - assert len(self.output_tensors[1]) == 0 - assert len(self.output_tensors_dw[0]) == 0 - assert len(self.output_tensors_dw[1]) == 0 - assert len(self.output_tensors_grad_dw[0]) == 0 - assert len(self.output_tensors_grad_dw[1]) == 0 - assert 
len(self.send_forward_buffer[0]) == 0 - assert len(self.send_forward_buffer[1]) == 0 - assert len(self.recv_forward_buffer[0]) == 0 - assert len(self.recv_forward_buffer[1]) == 0 - assert len(self.send_backward_buffer[0]) == 0 - assert len(self.send_backward_buffer[1]) == 0 - assert len(self.recv_backward_buffer[0]) == 0 - assert len(self.recv_backward_buffer[1]) == 0 + for chunk_id in [0, 1]: + assert len(self.input_tensors[chunk_id]) == 0, f"{self.input_tensors=}" + assert len(self.output_tensors[chunk_id]) == 0 + assert len(self.output_tensors_dw[chunk_id]) == 0 + assert len(self.output_tensors_grad_dw[chunk_id]) == 0 + assert len(self.send_forward_buffer[chunk_id]) == 0 + assert len(self.recv_forward_buffer[chunk_id]) == 0 + assert len(self.send_backward_buffer[chunk_id]) == 0 + assert len(self.recv_backward_buffer[chunk_id]) == 0 assert len(self.local_send_forward_buffer) == 0 assert len(self.local_send_backward_buffer) == 0 @@ -158,8 +151,8 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.num_microbatch % self.stage_manager.num_stages == 0 ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" - # if self.forward_only: - # self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + if self.forward_only: + self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) # if self.batch_size != self.last_batch_size: # self.enable_metadata_cache = False @@ -490,8 +483,6 @@ def backward_b_step( # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - input_obj_ = input_obj["hidden_states"] - # Attempt to disable gradient synchronization when using the LowLevelZeroPlugin. try: ctx = optimizer.no_sync() @@ -499,17 +490,27 @@ def backward_b_step( ctx = nullcontext() with ctx: + input_obj_, tree_spec = tree_flatten(input_obj) if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: - output_obj_ = output_obj["hidden_states"] + output_obj_, _ = tree_flatten(output_obj) + output_obj_grad_, _ = tree_flatten(output_obj_grad) optimizer.backward_by_grad( tensor=output_obj_, - grad=output_obj_grad, + grad=output_obj_grad_, inputs=input_obj_, retain_graph=True, ) - return input_obj_.grad + + # Collect the grad of the input_obj. 
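+        # (descriptive note, not from the original hunk) e.g. a chunk whose forward
+        # consumed {"hidden_states": h} returns {"hidden_states": h.grad} here; the
+        # schedule then passes this dict on as the dy of the peer stage's backward-b step.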
+ input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad def backward_w_step( self, @@ -539,13 +540,15 @@ def backward_w_step( if isinstance(model_chunk, ModuleList): model_chunk = model_chunk[model_chunk_id] + if output_obj_grad is None: optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: - output_obj_ = output_obj["hidden_states"] + output_obj_, _ = tree_flatten(output_obj) + output_obj_grad_, _ = tree_flatten(output_obj_grad) optimizer.backward_by_grad( tensor=output_obj_, - grad=output_obj_grad, + grad=output_obj_grad_, inputs=list(model_chunk.parameters()), retain_graph=False, ) @@ -861,11 +864,11 @@ def forward_backward_step( if self.forward_only: result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs) + self._free_buffers() else: result = self.run_forward_backward( model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) - - self.assert_buffer_empty() + self.assert_buffer_empty() return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index ce1e869c892c..8ad51610bfa8 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -136,7 +136,6 @@ def criterion(x, *args, **kwargs): idx = rank else: idx = world_size * 2 - rank - 1 - print(f"{i=}, {idx=}, {rank=}, {torch_model.layers[idx].weight.grad=}, {sharded_model[i].weight.grad=}") assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) @@ -154,14 +153,24 @@ def criterion(x, *args, **kwargs): assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) - # forward one step - torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output) - - pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) - if stage_manager.is_first_stage(ignore_chunk=True): - print(f"{torch_loss=}, {pp_ret['loss']}") - assert_close(torch_loss, pp_ret["loss"]) + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True + ) + if stage_manager.is_first_stage(ignore_chunk=True): + assert_close(torch_loss, pp_ret["loss"]) + + for layer in sharded_model: + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None + else: + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + torch.cuda.empty_cache() @pytest.mark.dist From f9f04e579fd75a1d68117d20f03dd9319d3b2831 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 11:28:52 +0000 Subject: [PATCH 20/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 8ee41ae753c8..7b9ca8dd726c 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py 
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -153,13 +153,13 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) if self.forward_only: self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 - # NOTE: disable metadata cache when batch size changes (not valid anymore) - # if self.batch_size != self.last_batch_size: - # self.enable_metadata_cache = False - # self.send_tensor_metadata = True - # self.send_grad_metadata = True - # self.tensor_metadata_recv = None - # self.grad_metadata_recv = None + # NOTE: disable metadata cache when batch size changes (not valid anymore) + # if self.batch_size != self.last_batch_size: + # self.enable_metadata_cache = False + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None self.last_batch_size = self.batch_size @@ -490,12 +490,14 @@ def backward_b_step( ctx = nullcontext() with ctx: - input_obj_, tree_spec = tree_flatten(input_obj) + input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)}) if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: - output_obj_, _ = tree_flatten(output_obj) - output_obj_grad_, _ = tree_flatten(output_obj_grad) + output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) + output_obj_grad_, _ = tree_flatten( + {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} + ) optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad_, @@ -544,8 +546,10 @@ def backward_w_step( if output_obj_grad is None: optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: - output_obj_, _ = tree_flatten(output_obj) - output_obj_grad_, _ = tree_flatten(output_obj_grad) + output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) + output_obj_grad_, _ = tree_flatten( + {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} + ) optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad_, From 6cf3ebc86fb3b84986aa4997ec37902e8844ff96 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 10 Sep 2024 17:33:09 +0800 Subject: [PATCH 21/47] [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin 
* [fix] fix optim bwd;
* [fix] fix optim bwd;
* [fix] rm output.data after send fwd;
* [fix] fix bwd step if condition; remove useless comments and format info;
* [fix] fix detach output & release output;
* [fix] rm requires_grad for output;
* [fix] fix requires_grad position, detach position and input&output local buffer append position;
* [feat] add memory assertion;
* [fix] fix mem check;
* [fix] mem assertion
* [fix] fix mem assertion
* [fix] fix mem; use a new model shape; only assert mem is less than or equal to the theoretical bound;
* [fix] fix model zoo import;
* [fix] fix redundant detach & clone; add buffer assertion at the end;
* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.requires_grad_ with tree_map;
* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
* [fix] add testcase with microbatch 4;
---
 .../booster/plugin/hybrid_parallel_plugin.py | 2 +-
 colossalai/interface/optimizer.py | 21 +-
 colossalai/pipeline/__init__.py | 3 +-
 colossalai/pipeline/schedule/__init__.py | 2 +
 colossalai/pipeline/schedule/v_schedule.py | 494 ++++++++++
 .../pipeline/schedule/zero_bubble_pp.py | 844 ++++++++++++++++++
 .../test_schedule/test_zerobubble_pp.py | 769 ++++++++++++++++
 7 files changed, 2131 insertions(+), 4 deletions(-)
 create mode 100644 colossalai/pipeline/schedule/v_schedule.py
 create mode 100644 colossalai/pipeline/schedule/zero_bubble_pp.py
 create mode 100644 tests/test_pipeline/test_schedule/test_zerobubble_pp.py

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index b4b40020fb2d..1b3b765c2ff0 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1103,7 +1103,7 @@ def __init__(
         self.stage_manager = PipelineStageManager(
             self.pg_mesh,
             pipeline_axis=self.pp_axis,
-            enable_interleave=pp_style == "interleaved",
+            enable_interleave=(pp_style == "interleaved"),
             num_model_chunks=num_model_chunks,
             num_layers_per_stage=num_layers_per_stage,
         )
diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index 6cd74b3b4305..a236434a55d6 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -55,8 +55,25 @@ def backward(self, loss: Tensor, *args, **kwargs):
         """
         loss.backward(*args, **kwargs)

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
-        torch.autograd.backward(tensor, grad)
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
+        """
+        Performs a partial backward pass for dx or dw:
+        for dx, we only calculate dx = w*dy here;
+        for dw, we only calculate dw = x*dy here.
+
+        Args:
+            tensor (Tensor): y or loss of the current chunk;
+            grad (Tensor): dy of the current chunk;
+            inputs (Tensor): for dx, the input x of the current chunk;
+                for dw, the parameters w of the current chunk;
+            retain_graph (bool): defaults to False; backward_b passes True so the graph survives for the later dw pass.
+        """
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )

     def state_dict(self):
         """
diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py
index 4754212c1914..5d44530e7edd 100644
--- a/colossalai/pipeline/__init__.py
+++ b/colossalai/pipeline/__init__.py
@@ -1,11 +1,12 @@
 from .p2p import PipelineP2PCommunication
-from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
+from .schedule import InterleavedSchedule, 
OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler from .stage_manager import PipelineStageManager __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", "PipelineP2PCommunication", "PipelineStageManager", ] diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py index 6845dc23753b..05dd24e8169e 100644 --- a/colossalai/pipeline/schedule/__init__.py +++ b/colossalai/pipeline/schedule/__init__.py @@ -1,9 +1,11 @@ from .base import PipelineSchedule from .interleaved_pp import InterleavedSchedule from .one_f_one_b import OneForwardOneBackwardSchedule +from .zero_bubble_pp import ZeroBubbleVPipeScheduler __all__ = [ "PipelineSchedule", "OneForwardOneBackwardSchedule", "InterleavedSchedule", + "ZeroBubbleVPipeScheduler", ] diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py new file mode 100644 index 000000000000..9eebebdea463 --- /dev/null +++ b/colossalai/pipeline/schedule/v_schedule.py @@ -0,0 +1,494 @@ +# Refer from Zero Bubble Pipeline Parallelism. +# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism +# Paper: https://arxiv.org/abs/2401.10241 +# The following applies to all files unless otherwise noted: +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
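+
+# A minimal usage sketch (editorial illustration; the cost and memory numbers
+# below are made up): PipelineGraph is sized by the pipeline stages and
+# microbatches plus per-pass time/memory estimates, and get_v_schedule()
+# returns one execution-ordered list of ScheduledNode per stage:
+#
+#   graph = PipelineGraph(
+#       n_stage=4, n_micro=8,
+#       f_cost=1, b_cost=1, w_cost=1, c_cost=0,  # F / backward-B / backward-W / comm times
+#       f_mem=2, b_mem=-1, w_mem=-1,             # per-pass activation-memory deltas
+#   )
+#   local_order = graph.get_v_schedule()  # local_order[rank] -> List[ScheduledNode]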
+ +from collections import deque +from dataclasses import dataclass + + +@dataclass(eq=True, frozen=True) +class ScheduledNode: + type: str + chunk: int + stage: int + minibatch: int + start_time: int = 0 + completion_time: int = 0 + rollback: bool = False + + +class PipelineGraph(object): + """PipelineGraph""" + + def __init__( + self, + n_stage, + n_micro, + f_cost, + b_cost, + w_cost, + c_cost, + f_mem, + b_mem, + w_mem, + max_mem=None, + ): + self.n_node = 6 * n_stage * n_micro + self.n_stage = n_stage + self.n_micro = n_micro + self.f_cost = f_cost + self.b_cost = b_cost + self.w_cost = w_cost + self.c_cost = c_cost + self.f_mem = f_mem + self.b_mem = b_mem + self.w_mem = w_mem + self.fbw_cost = [f_cost, b_cost, w_cost] + self.fbw_mem = [f_mem, b_mem, w_mem] + self.max_mem = max_mem or f_mem * self.n_stage * 2 + + def get_id(self, cat, chunk, stage, micro): + return ( + cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro + ) + + def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None): + count = [] + for i in range(self.n_stage): + count.append([0] * 6) + + end_time = [-1] * self.n_node + cur_time = [0] * self.n_stage + mem = [0] * self.n_stage + stage_bubble = [0] * self.n_stage + pending_w = [deque() for _ in range(self.n_stage)] + schedule = [[] for _ in range(self.n_stage)] + stage_str = [" " * i for i in range(self.n_stage)] + + if approved_bubble is None: + approved_bubble = [-1] * self.n_stage + max_approved_bubble = max(approved_bubble) + + def get_max_stage_bubble(stage=-1): + max_stage_bubble = 0 + for bb in stage_bubble: + max_stage_bubble = max(max_stage_bubble, bb) + if stage >= 0: + max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage]) + return max_stage_bubble + + def put_w(stage): + assert len(pending_w[stage]) > 0 + _, chunk_, _ = pending_w[stage].popleft() + put(2, chunk_, stage) + + def put(cat, chunk, stage, assert_cnt=True): + _tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat] + _cnt = count[stage][cat * 2 + chunk] + # assert _cnt < self.n_micro + if _cnt >= self.n_micro: + if not assert_cnt: + stage_str[stage] += " " + cur_time[stage] = _tmp # TODO + return + assert False + assert mem[stage] + self.fbw_mem[cat] <= self.max_mem + stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1))) + if cat > 0 or chunk > 0: + last_id = cat * 2 + chunk - 1 + if cat < 2: + # if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0: + # print(cat, chunk, stage, _cnt) + # self.print_details(end_time) + assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0 + else: + assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0 + if chunk == 1 and cat < 2: + if stage < self.n_stage - 1: + _fa_id = self.get_id(cat, chunk, stage + 1, _cnt) + assert end_time[_fa_id] >= 0 + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + if chunk == 0 and cat < 2: + if stage > 0: + _fa_id = self.get_id(cat, chunk, stage - 1, _cnt) + # if end_time[_fa_id] < 0: + # print(cat, chunk, stage, _cnt) + # self.print_details(end_time) + assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}" + _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat]) + _id = self.get_id(cat, chunk, stage, _cnt) + if count[stage][0] > 0: + stage_bubble[stage] += _tmp - _no_bubble + end_time[_id] = _tmp + cur_time[stage] = _tmp + mem[stage] += self.fbw_mem[cat] + # noinspection PyTypeChecker + 
schedule[stage].append((cat, chunk, _cnt)) + if cat == 1: + pending_w[stage].append((2, chunk, _cnt)) + count[stage][cat * 2 + chunk] += 1 + + # for _ in range(2 * self.n_stage): + # for i in range(self.n_stage): + # if count[i][1] >= count[i][0]: + # put(0, 0, i, assert_cnt=False) + # continue + # if i == self.n_stage - 1: + # put(0, 1, i, assert_cnt=False) + # continue + # fa_id = self.get_id(0, 1, i + 1, count[i][1]) + # if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO + # put(0, 1, i, assert_cnt=False) + # else: + # put(0, 0, i, assert_cnt=False) + + for i in range(self.n_stage): + put(0, 0, i) + for i in range(self.n_stage - 1, -1, -1): + if i == self.n_stage - 1: + put(0, 1, i) + continue + tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost + while ( + mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem + and cur_time[i] + self.fbw_cost[0] <= tmp + and count[i][0] < self.n_micro + ): + for j in range(i + 1): + put(0, 0, j) + put(0, 1, i) + iter_chunk_ = 0 + end_tmp = 0 + for i in range(self.n_stage): + if i == 0: + end_tmp = cur_time[0] + self.fbw_cost[1] + continue + tmp = end_tmp + self.c_cost + while ( + count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1] + or count[i][1] <= count[i - 1][1] < self.n_micro + ): + for j in range(self.n_stage - 1, i - 1, -1): + if count[j][iter_chunk_] < self.n_micro: + put(0, iter_chunk_, j) + iter_chunk_ = 1 - iter_chunk_ + # while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp: + # if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]: + # break + # for j in range(self.n_stage - 1, i - 1, -1): + # if count[j][iter_chunk_] < self.n_micro: + # put(0, iter_chunk_, j) + # iter_chunk_ = 1 - iter_chunk_ + # end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1] + + # init_bubble = get_max_stage_bubble() + # print(stage_bubble) + for _ in range(2 * self.n_micro): + # check mem before putting b + for i in range(self.n_stage): + while mem[i] + self.fbw_mem[1] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + b0_ranks, b1_ranks = [], [] + for i in range(self.n_stage): + if count[i][3] >= count[i][2]: + b0_ranks.append(i) + elif i == self.n_stage - 1: + b1_ranks.append(i) + else: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro: + b1_ranks.append(i) + else: + b0_ranks.append(i) + b_ranks = [] + # put b1 + for i in reversed(b1_ranks): + b_ranks.append((i, 1)) + # put b0 + for i in b0_ranks: + b_ranks.append((i, 0)) + for i, _chunk_ in b_ranks: + fa_id = -1 + if _chunk_ == 1 and i < self.n_stage - 1: + fa_id = self.get_id(1, 1, i + 1, count[i][3]) + if _chunk_ == 0 and i > 0: + fa_id = self.get_id(1, 0, i - 1, count[i][2]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if _chunk_ == 1: + put_w(i) + elif fill_b: + put_w(i) + put(1, _chunk_, i) + + # put f + for i in range(self.n_stage): + if count[i][1] >= self.n_micro: + continue + put_item = None + if count[i][1] >= count[i][0]: + put_item = 0 + elif i == self.n_stage - 1: + put_item = 1 + else: + if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0: + put_item = 1 + elif count[i][0] < self.n_micro: + if i == 0: + put_item = 0 + elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0: + put_item = 0 + if put_item is None: + continue + # check mem 
before putting f + while mem[i] + self.fbw_mem[0] > self.max_mem: + assert len(pending_w[i]) > 0 + put_w(i) + fa_id = -1 + if put_item == 0 and i > 0: + fa_id = self.get_id(0, 0, i - 1, count[i][0]) + if put_item == 1 and i < self.n_stage - 1: + fa_id = self.get_id(0, 1, i + 1, count[i][1]) + while ( + len(pending_w[i]) > 0 + and fa_id >= 0 + and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2] + ): + # fill the bubble + put_w(i) + if ( + len(pending_w[i]) > 0 + and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i] + ): + if fill_f: + put_w(i) + put(0, put_item, i) + + for i in range(self.n_stage): + while len(pending_w[i]) > 0: + put_w(i) + + # for i in range(self.n_stage): + # print(stage_str[i]) + + max_bubble = get_max_stage_bubble() + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + max_bubble / expected_time + # print("%6.4f" % bubble_rate, "->", stage_bubble) + if max_approved_bubble < 0 or max_bubble < max_approved_bubble: + _schedule, _end_time, _max_bubble = self.try_v_schedule( + fill_f=fill_f, + fill_b=fill_b, + approved_bubble=stage_bubble, + ) + if _max_bubble < max_bubble: + return _schedule, _end_time, _max_bubble + # print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \ + # (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble) + return schedule, end_time, max_bubble + + def print_details(self, end_time, print_scaling=1): + for stage in range(self.n_stage): + stage_str = ["."] * int(max(end_time) / print_scaling) + for _cat in range(3): + for _chunk in range(2): + for _micro in range(self.n_micro): + _id = self.get_id(_cat, _chunk, stage, _micro) + if end_time[_id] < 0: + continue + end = int(end_time[_id] / print_scaling) + start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling) + for j in range(start, end): + if j == start or j == end - 1: + stage_str[j] = "FfBbWw"[_cat * 2 + _chunk] + elif j == start + 1: + if _micro >= 10: + stage_str[j] = str(_micro // 10) + else: + stage_str[j] = str(_micro) + elif j == start + 2 and _micro >= 10: + stage_str[j] = str(_micro % 10) + else: + stage_str[j] = "-" + _str = "" + for _c in stage_str: + _str += _c + print(_str) + + def get_v_schedule(self, only_run_time=False): + schedule, end_time, max_bubble = None, None, None + expected_time = sum(self.fbw_cost) * self.n_micro * 2 + for fill_b in [True, False]: + for fill_f in [True, False]: + _schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f) + # print("") + if max_bubble is None or _max_bubble < max_bubble: + max_bubble = _max_bubble + schedule = _schedule + end_time = _end_time + if only_run_time: + return max_bubble + expected_time + # self.print_details(end_time, print_scaling=1) + max_bubble / (expected_time + max_bubble) + # print("%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % \ + # (self.n_stage, self.n_micro, *self.fbw_cost, self.c_cost, self.max_mem // self.f_mem, bubble_rate)) + local_order = [[] for _ in range(self.n_stage)] + comm_id = {} + comm_id_counter = 0 + post_validation_time = 0 + for i in range(self.n_stage - 1, -1, -1): + pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1) + post_validation_time = max( + post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost + ) + # post_validation_time = 0 + # print(i, pv_id, post_validation_time) + for it in ["RECV_", "SEND_", ""]: + if i == 0 and it == "SEND_": + continue + if i == self.n_stage - 1 and it == 
"RECV_": + continue + # stage_ = i - 1 if it == "RECV_" else i + stage_ = i + local_order[stage_].append( + ScheduledNode( + type=it + "POST_VALIDATION", + chunk=0, + stage=stage_, + minibatch=0, + start_time=post_validation_time, + completion_time=post_validation_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + comm_id_counter += 1 + for i in range(self.n_stage): + for _cat_, _chunk_, _micro_ in schedule[i]: + complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)] + local_order[i].append( + ScheduledNode( + type="FBW"[_cat_], + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=i, + minibatch=_micro_, + start_time=complete_time - self.fbw_cost[_cat_], + completion_time=complete_time, + ) + ) + if _cat_ == 2: # no communication for W + continue + cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD" + + def communicate(send_recv, stage_): + # noinspection PyTypeChecker + local_order[stage_].append( + ScheduledNode( + type=send_recv + cat_str, + chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_, + stage=stage_, + minibatch=_micro_, + start_time=complete_time, + completion_time=complete_time, + ) + ) + comm_id[local_order[stage_][-1]] = comm_id_counter + + if _chunk_ == 1 and i > 0: + communicate("SEND_", i) + communicate("RECV_", i - 1) + if _chunk_ == 0 and i < self.n_stage - 1: + communicate("SEND_", i) + communicate("RECV_", i + 1) + comm_id_counter += 1 + for rank in range(self.n_stage): + # For nodes with the same timestamp on the same stage, communication will be prioritized. + def even_breaker(x: ScheduledNode): + # Compute nodes are always delayed. + if x.type in ["F", "B", "W"]: + return comm_id_counter + # For comm nodes, order by their unique comm id + return comm_id[x] + + local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x)))) + # If a recv with intersects with previous computation, reorder them so that recv + # is executed before computation and hence can be overlapped. 
+ for i in range(len(local_order[rank])): + if ( + i > 0 + and local_order[rank][i - 1].type in {"F", "B", "W"} + and local_order[rank][i].type.startswith("RECV") + and "POST_VALIDATION" not in local_order[rank][i].type + and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time + ): + local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i] + + local_order_with_rollback = [[] for _ in range(self.n_stage)] + for rank in range(self.n_stage): + rollback_comm = set() + if rank > 0: + for node in local_order[rank - 1]: + if node.type == "POST_VALIDATION": + break + if node.type == "SEND_FORWARD": + assert node.chunk == 0 + rollback_comm.add(node.minibatch) + for node in local_order[rank]: + if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm: + rollback = True + rollback_comm.remove(node.minibatch) + else: + rollback = False + local_order_with_rollback[rank].append( + ScheduledNode( + type=node.type, + chunk=node.chunk, + stage=node.stage, + minibatch=node.minibatch, + start_time=node.start_time, + completion_time=node.completion_time, + rollback=rollback, + ) + ) + assert len(rollback_comm) == 0 + # for node in local_order_with_rollback[rank]: + # print(f"Rank {rank} Node info {node}") + # print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=", ") + # print() + + return local_order_with_rollback diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py new file mode 100644 index 000000000000..c1c4f13c68c2 --- /dev/null +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -0,0 +1,844 @@ +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.cuda +from torch.nn import Module, ModuleList +from torch.utils._pytree import tree_map + +from colossalai.accelerator import get_accelerator +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.schedule.v_schedule import ScheduledNode +from colossalai.pipeline.stage_manager import PipelineStageManager + +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device +from .base import PipelineSchedule + +AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} + + +def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: + if wait_handles is not None: + for req in wait_handles: + req.wait() + + +def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): + """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. + + This method should be called right after the output tensor has been + sent to the next pipeline stage. At this point, the output tensor is + only useful for its '.grad_fn' field, and not its '.data'. + """ + if (out is None) or (not deallocate_pipeline_outputs): + return + assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ + assert out._base is None, "counter-productive to free a view of another tensor." 
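+    # (descriptive note) only the storage is released below: 'out' keeps its
+    # Python object, shape metadata and grad_fn, and autograd holds its own
+    # saved tensors on the graph, so backward can still run while the
+    # activation bytes return to the allocator as soon as the send completes.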
+    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
+    out.data.untyped_storage().resize_(0)
+
+
+class ZeroBubbleVPipeScheduler(PipelineSchedule):
+    def __init__(
+        self,
+        stage_manager: PipelineStageManager,
+        schedule: List[ScheduledNode],
+        num_model_chunks: int,
+        num_microbatch: Optional[int] = None,
+        microbatch_size: Optional[int] = None,
+        enable_metadata_cache: bool = True,
+        overlap_p2p: bool = True,
+    ):
+        super().__init__(stage_manager)
+        # batch info
+        self.num_microbatch = num_microbatch
+        self.microbatch_size = microbatch_size
+        self.num_model_chunks = num_model_chunks
+        self.batch: Any
+        self.batch_size: int
+        self.last_batch_size: Optional[int] = None
+        self.microbatch_offset: List[int]
+
+        self.schedules = schedule
+        # TODO: optim post valid
+        self.do_post_validation = False
+
+        # P2PMeta cache
+        # self.enable_metadata_cache = enable_metadata_cache
+        # self.send_tensor_metadata = True
+        # self.send_grad_metadata = True
+        # self.tensor_metadata_recv = None
+        # self.grad_metadata_recv = None
+
+        # P2P communication
+        self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
+
+        # init communication map
+        self.communication_map = {
+            "SEND_FORWARD": self.send_forward,
+            "RECV_FORWARD": self.recv_forward,
+            "SEND_BACKWARD": self.send_backward,
+            "RECV_BACKWARD": self.recv_backward,
+        }
+
+        # init buffer
+        self._free_buffers()
+
+    def _free_buffers(self):
+        # free local buffer
+        # a two-dim array: the first dim is the model chunk, the second is the microbatch queue
+
+        # x & y buffer for schedule b
+        self.input_tensors = [[], []]
+        self.output_tensors = [[], []]
+
+        # y & dy buffer for schedule w
+        self.output_tensors_dw = [[], []]
+        self.output_tensors_grad_dw = [[], []]
+
+        # buffer for communication
+        self.send_forward_buffer = [[], []]
+        self.recv_forward_buffer = [[], []]
+        self.send_backward_buffer = [[], []]
+        self.recv_backward_buffer = [[], []]
+
+        # y buffer for local send fwd
+        self.local_send_forward_buffer = []
+        # dy buffer for local send bwd
+        self.local_send_backward_buffer = []
+
+    def assert_buffer_empty(self):
+        # assert that every buffer is empty at the end
+        assert len(self.input_tensors[0]) == 0
+        assert len(self.input_tensors[1]) == 0
+        assert len(self.output_tensors[0]) == 0
+        assert len(self.output_tensors[1]) == 0
+        assert len(self.output_tensors_dw[0]) == 0
+        assert len(self.output_tensors_dw[1]) == 0
+        assert len(self.output_tensors_grad_dw[0]) == 0
+        assert len(self.output_tensors_grad_dw[1]) == 0
+        assert len(self.send_forward_buffer[0]) == 0
+        assert len(self.send_forward_buffer[1]) == 0
+        assert len(self.recv_forward_buffer[0]) == 0
+        assert len(self.recv_forward_buffer[1]) == 0
+        assert len(self.send_backward_buffer[0]) == 0
+        assert len(self.send_backward_buffer[1]) == 0
+        assert len(self.recv_backward_buffer[0]) == 0
+        assert len(self.recv_backward_buffer[1]) == 0
+        assert len(self.local_send_forward_buffer) == 0
+        assert len(self.local_send_backward_buffer) == 0
+
+    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
+        """Load a batch from data iterator.
+
+        Args:
+            data_iter (Iterable): Data iterator.
+            device (Optional[torch.device], optional): Target device. Defaults to None.
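+
+        Note:
+            For example, with microbatch_size=2 and a batch of 8 samples,
+            num_microbatch resolves to 4; each model chunk keeps its own offset
+            into the same batch and advances it by microbatch_size on every
+            load_micro_batch call. (Illustrative numbers, added for clarity;
+            they are not from the original patch.)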
+ """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + self.batch = batch + self.batch_size = get_batch_size(batch) + + if self.microbatch_size is None: + assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" + self.microbatch_size = self.batch_size // self.num_microbatch + if self.num_microbatch is None: + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" + self.num_microbatch = self.batch_size // self.microbatch_size + + if not self.forward_only: + assert self.last_batch_size is None or self.last_batch_size == self.batch_size + assert self.batch_size == self.microbatch_size * self.num_microbatch + + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" + + if self.forward_only: + self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + # NOTE: disable metadata cache when batch size changes (not valid anymore) + # if self.batch_size != self.last_batch_size: + # self.enable_metadata_cache = False + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None + + self.last_batch_size = self.batch_size + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + assert ( + microbatch_id < self.num_microbatch * self.num_model_chunks + ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" + microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not is_forward: + # Reverse order + model_chunk_id = self.num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + Any: The wait handles for the communication. 
+ """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + ################ + # chunk = 0 & is_first_stage + # do nothing; cause u are chunk 0 in first rank, u have no prev rank; + ################# + if self.stage_manager.is_first_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 0 & not is_first_stage + # Recv y from PREV_rank as input + ################# + else: + prev_rank = self.stage_manager.get_prev_rank() + input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank) + self.recv_forward_buffer[model_chunk_id].append(input_tensor) + return input_tensor, wait_handles + + else: + ################ + # chunk = 1 & is_last_stage + # do nothing; cause u get y from local_send_forward_buffer in schedule f + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 1 & not is_last_stage + # recv y from NEXT_rank as input + ################ + else: + next_rank = self.stage_manager.get_next_rank() + input_tensor, wait_handles = self.comm.recv_forward(next_rank) + self.recv_forward_buffer[model_chunk_id].append(input_tensor) + return input_tensor, wait_handles + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + Any: The wait handles for the communication. + """ + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + if model_chunk_id == 0: + # bwd chunk0 is right V; + ################ + # chunk = 0 & is_last_stage + # do nothing; Already get dy from local_send_backward_buffer in schedule b + ################ + if self.stage_manager.is_last_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 0 & not is_last_stage + # Recv bwd from next stage; + ################ + else: + next_rank = self.stage_manager.get_next_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank) + self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + return output_tensor_grad, wait_handles + + else: + # bwd chunk1 is left V; + ################ + # chunk = 1 & is_first_stage + # do nothing; get loss from local + ################ + if self.stage_manager.is_first_stage(ignore_chunk=True): + return None, [] + + ################ + # chunk = 1 & not first stage + # recv_backward recv bwd from prev stage; + ################ + else: + prev_rank = self.stage_manager.get_prev_rank() + output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank) + self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + return output_tensor_grad, wait_handles + + def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: + """Sends the input tensor to the next stage in pipeline. + For ZBV. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the recipient of the tensor. + + Returns: + Any: The wait handles for the communication. 
+        """
+
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if model_chunk_id == 0:
+                ################
+                # chunk = 0 && is_last_stage
+                # do nothing; keep y in local_send_forward_buffer for the turn of the V
+                ################
+                if self.stage_manager.is_last_stage(ignore_chunk=True):
+                    return []
+
+                ################
+                # chunk = 0 && not is_last_stage
+                # send y to the NEXT stage via self.comm.send_forward
+                ################
+                else:
+                    next_rank = self.stage_manager.get_next_rank()
+                    output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
+                    send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
+                    return send_handles
+
+            else:
+                ################
+                # chunk = 1 && is_first_stage
+                # do nothing; the LOSS stays local and is consumed by schedule_b
+                ################
+                if self.stage_manager.is_first_stage(ignore_chunk=True):
+                    return []
+
+                ################
+                # chunk = 1 && not is_first_stage
+                # send y to the PREV stage via self.comm.send_forward
+                ################
+                else:
+                    prev_rank = self.stage_manager.get_prev_rank()
+                    output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
+                    send_handles = self.comm.send_forward(output_tensor, prev_rank)
+                    return send_handles
+
+    def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
+        """Sends the gradient tensor to the previous stage in pipeline.
+        For ZBV.
+
+        Args:
+            model_chunk_id (int): The current model chunk idx.
+            prev_rank (int, optional): The rank of the recipient of the tensor.
+
+        Returns:
+            Any: The wait handles for the communication.
+        """
+
+        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
+            if model_chunk_id == 0:
+                # bwd of chunk 0 runs along the right leg of the V
+                ################
+                # chunk = 0 && is_first_stage
+                # do nothing; chunk 0 on the first stage is where backward ends
+                ################
+                if self.stage_manager.is_first_stage(ignore_chunk=True):
+                    return []
+
+                ################
+                # chunk = 0 && not is_first_stage
+                # send dx to the PREV stage
+                ################
+                else:
+                    prev_rank = self.stage_manager.get_prev_rank()
+                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
+                    send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
+                    return send_handles
+
+            # bwd of chunk 1 runs along the left leg of the V
+            else:
+                ################
+                # chunk = 1 && is_last_stage
+                # do nothing; input_tensor_grad was already placed in local_send_backward_buffer in schedule_b
+                ################
+                if self.stage_manager.is_last_stage(ignore_chunk=True):
+                    return []
+
+                ################
+                # chunk = 1 && not is_last_stage
+                # send dx to the NEXT stage
+                ################
+                else:
+                    next_rank = self.stage_manager.get_next_rank()
+                    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
+                    send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
+                    return send_handles
+
+    def forward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        model_chunk_id: int,
+        input_obj: Optional[dict],
+        criterion: Callable,
+        accum_loss: Optional[torch.Tensor] = None,
+        outputs: Optional[List[Any]] = None,
+    ) -> Union[torch.Tensor, dict]:
+        """Forward one step of the pipeline.
+
+        Args:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            input_obj (Optional[dict]): x;
+            criterion (Callable): loss function;
+            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
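forward_step (continued below) divides each microbatch loss by num_microbatch before accumulating it, so the accumulated scalar equals the full-batch mean. A self-contained sketch of that invariant, with made-up shapes and the same squared-mean criterion the tests use:

```python
import torch

num_microbatch = 4
criterion = lambda y: (y * y).mean()

full_batch = torch.randn(8, 16)
accum_loss = torch.zeros(())
for micro in full_batch.chunk(num_microbatch):
    accum_loss += criterion(micro).detach() / num_microbatch

# equal-sized microbatches make the mean-of-means equal the full-batch mean
assert torch.allclose(accum_loss, criterion(full_batch))
```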
+ + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + # Load input ids, attention mask and labels + # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + + # for the first stage, input_obj is None + # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. + # Only attention_mask from micro_batch is used + + with self.stage_manager.switch_model_chunk_id(model_chunk_id): + # fwd calculate + output_obj = model_chunk[model_chunk_id](input_obj) + # last layer in model + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + loss = criterion(output_obj) / self.num_microbatch + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_b_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: + """Backward dx step of the pipeline; we calculate "dx = w*dy" here; + + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): x. + output_obj (Union[dict, torch.Tensor]): y. + output_obj_grad (dict): dy. + + Returns: + Optional[dict]: dx. + """ + # calculate bwd b step ; only dx = w*dy; + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # loss backward; output_obj is loss; so output_obj_grad should be None + assert output_obj_grad is None + + optimizer.backward_by_grad( + tensor=output_obj, + grad=output_obj_grad, + inputs=input_obj, + retain_graph=True, + ) + return input_obj.grad + + def backward_w_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ): + """Backward dw step of the pipeline; we calculate "dw = x*dy" here; + + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + optimizer (OptimizerWrapper): Optimizer to update the model + output_obj (Union[dict, torch.Tensor]): y. + output_obj_grad (dict): dy. 
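schedule_f, defined a little further down, ships a detached clone of the output to the peer stage while keeping the graph-carrying tensor in a local buffer with its payload freed by deallocate_output_tensor. A minimal sketch of why that is safe, assuming PyTorch 2.x and a grad_fn that does not read the output tensor's own storage:

```python
import torch

x = torch.randn(4, 4, requires_grad=True)
y = x * 2.0
send_copy = y.clone().detach()       # safe to ship over P2P
y.data.untyped_storage().resize_(0)  # free the payload, keep shape and grad_fn

# backward only walks the graph hanging off y; y's own storage is never read
y.backward(torch.ones(4, 4))
assert x.grad is not None
```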
+
+        Returns:
+            Nothing. This step only accumulates dw; the weights themselves are updated later by the optimizer step.
+        """
+        # calculate bwd w step; only dw = x * dy
+
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # loss backward; output_obj is loss
+            output_obj_grad = None
+        optimizer.backward_by_grad(
+            tensor=output_obj,
+            grad=output_obj_grad,
+            inputs=list(model_chunk[model_chunk_id].parameters()),
+            retain_graph=False,
+        )
+
+    def schedule_f(
+        self,
+        scheduled_node,
+        model_chunk: torch.nn.ModuleList,
+        model_chunk_id: int,
+        criterion: Callable,
+        accum_loss: Optional[torch.Tensor] = None,
+        outputs: Optional[List[Any]] = None,
+    ):
+        """A complete forward schedule: recv fwd --> compute fwd --> send fwd.
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            criterion (Callable): loss function;
+            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+
+        Returns:
+            Nothing.
+        """
+        micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
+        # Step1: recv fwd
+        if model_chunk_id == 0:
+            # chunk 0 on the first stage reads the microbatch directly
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                input_obj = micro_batch
+            else:
+                input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+        else:
+            # chunk 1 on the last stage takes the turn of the V from the local buffer
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                input_obj = self.local_send_forward_buffer.pop(0)
+            # otherwise receive from the next rank
+            else:
+                input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+
+        # enable grad on input_obj so that backward_b_step can produce dx
+        tree_map(torch.Tensor.requires_grad_, input_obj)
+
+        # Step2: fwd step
+        output_obj = self.forward_step(
+            model_chunk=model_chunk,
+            model_chunk_id=model_chunk_id,
+            input_obj=input_obj,
+            criterion=criterion,
+            accum_loss=accum_loss,
+            outputs=outputs,
+        )
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # do not detach the LOSS; it is the root of the backward pass
+            pass
+        else:
+            detached_output_obj = output_obj.clone().detach()
+
+        # Step3: send fwd
+        # add output to send_fwd_buffer
+        if model_chunk_id == 0:
+            # chunk 0 on the last stage hands y over locally (turn of the V)
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                self.local_send_forward_buffer.append(detached_output_obj)
+            else:
+                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
+        else:
+            # chunk 1 on the first stage: end of fwd; the LOSS stays local
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                pass
+            else:
+                self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
+
+        # add input and output object for backward b
+        self.input_tensors[model_chunk_id].append(input_obj)
+        # deallocate the payload of output_obj; bwd b & w only need its autograd graph (grad_fn)
+        deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(output_obj)
+        # add output object for backward w
+        self.output_tensors_dw[model_chunk_id].append(output_obj)
+
+    def schedule_b(
+        self,
+        scheduled_node,
+        model_chunk: Union[ModuleList, Module],
+        model_chunk_id: int,
+        optimizer: OptimizerWrapper,
+    ):
+        """A complete backward b schedule: recv bwd --> compute bwd step --> send bwd.
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+        Returns:
+            Nothing.
+        """
+
+        # Step1: recv bwd
+        if model_chunk_id == 0:
+            # chunk 0 on the last stage receives output_grad from the local buffer (turn of the V)
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                output_tensor_grad = self.local_send_backward_buffer.pop(0)
+            # chunk 0 on other stages receives output_grad from recv_backward_buffer
+            else:
+                output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+        else:
+            # chunk 1 on the first stage: the backward root is the local LOSS, so no grad is received
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                output_tensor_grad = None
+            # chunk 1 on other stages receives output_grad from recv_backward_buffer
+            else:
+                output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+
+        # get input and output object from buffer
+        input_obj = self.input_tensors[model_chunk_id].pop(0)
+        output_obj = self.output_tensors[model_chunk_id].pop(0)
+
+        # save output_tensor_grad for dw
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # we save the loss here
+            self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
+        else:
+            # we save output_tensor_grad here
+            self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
+
+        # Step2: bwd step
+        input_object_grad = self.backward_b_step(
+            model_chunk=model_chunk,
+            model_chunk_id=model_chunk_id,
+            optimizer=optimizer,
+            input_obj=input_obj,
+            output_obj=output_obj,
+            output_obj_grad=output_tensor_grad,
+        )
+
+        # Step3: send bwd
+        if model_chunk_id == 0:
+            # chunk 0 on the first stage: end of bwd; nothing to send
+            if self.stage_manager.is_first_stage(ignore_chunk=True):
+                pass
+            # otherwise queue input_object_grad in send_backward_buffer
+            else:
+                self.send_backward_buffer[model_chunk_id].append(input_object_grad)
+        else:
+            # chunk 1 on the last stage hands dx over locally (turn of the V)
+            if self.stage_manager.is_last_stage(ignore_chunk=True):
+                self.local_send_backward_buffer.append(input_object_grad)
+            # otherwise send to the next rank
+            else:
+                self.send_backward_buffer[model_chunk_id].append(input_object_grad)
+
+    def schedule_w(
+        self,
+        scheduled_node,
+        model_chunk: Union[ModuleList, Module],
+        model_chunk_id: int,
+        optimizer: OptimizerWrapper,
+    ):
+        """A complete backward w schedule: pop y & dy from the buffers --> compute dw.
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+        Returns:
+            Nothing.
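The B/W split that schedule_b and schedule_w orchestrate rests on torch.autograd's `inputs=` argument: backprop into the activations first, into the parameters later. A self-contained sketch of that mechanism (not the scheduler's actual call path):

```python
import torch

w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4, requires_grad=True)
y = x @ w
dy = torch.ones_like(y)

# B step: dx only; keep the graph alive for the later W step
y.backward(dy, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# W step: dw only; the graph may now be freed
y.backward(dy, inputs=[w])
assert w.grad is not None
```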
+        """
+
+        # get y & dy from buffer
+        output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
+        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
+
+        self.backward_w_step(
+            model_chunk=model_chunk,
+            model_chunk_id=model_chunk_id,
+            optimizer=optimizer,
+            output_obj=output_obj,
+            output_obj_grad=output_obj_grad,
+        )
+
+    def run_forward_only(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        assert self.forward_only
+
+        # prepare batch
+        self.load_batch(data_iter)
+
+        # prepare accum loss & output
+        accum_loss = None
+
+        # the loss is accumulated on the first stage, where chunk 1 finishes the forward pass
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
+        # walk through every scheduled node assigned to this stage
+        schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
+        for it in range(len(schedule)):
+            scheduled_node = schedule[it]
+
+            if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
+                # communication
+                communication_func = self.communication_map[scheduled_node.type]
+                communication_func(scheduled_node.chunk)
+            if scheduled_node.type == "F":
+                self.schedule_f(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    criterion=criterion,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
+                )
+        # return loss & output
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
+
+    def run_forward_backward(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        """
+        Runs the zero bubble schedule, with communication between pipeline stages.
+        """
+        # prepare batch
+        self.load_batch(data_iter)
+
+        # prepare accum loss & output
+        accum_loss = None
+
+        # the loss is accumulated on the first stage, where chunk 1 finishes the forward pass
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
+        # walk through every scheduled node assigned to this stage
+        schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
+        for it in range(len(schedule)):
+            scheduled_node = schedule[it]
+            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+                # communication
+                communication_func = self.communication_map[scheduled_node.type]
+                communication_func(scheduled_node.chunk)
+
+            if scheduled_node.type == "F":
+                self.schedule_f(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    criterion=criterion,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
+                )
+            elif scheduled_node.type == "B":
+                self.schedule_b(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    optimizer=optimizer,
+                )
+            elif scheduled_node.type == "W":
+                self.schedule_w(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    optimizer=optimizer,
+                )
+
+        # return loss & output
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+        Returns:
+            dict: A dict with keys: 'loss' and 'outputs'.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
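Both run loops above dispatch on scheduled_node.type: the communication types share one handler map, while F/B/W route to the schedule_* methods. A toy interpreter showing the shape of that loop, with hypothetical handlers standing in for the real ones:

```python
from dataclasses import dataclass

@dataclass
class Node:
    type: str
    chunk: int

COMM_TYPES = {"RECV_FORWARD", "SEND_FORWARD", "RECV_BACKWARD", "SEND_BACKWARD"}

def run_schedule(nodes, comm, compute):
    for node in nodes:
        if node.type in COMM_TYPES:
            comm(node.type, node.chunk)      # e.g. self.communication_map[...]
        elif node.type in ("F", "B", "W"):
            compute[node.type](node.chunk)   # e.g. schedule_f / schedule_b / schedule_w

run_schedule(
    [Node("RECV_FORWARD", 0), Node("F", 0), Node("SEND_FORWARD", 0)],
    comm=lambda t, c: print("comm", t, "chunk", c),
    compute={"F": lambda c: print("forward chunk", c)},
)
```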
+ + if self.forward_only: + result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs) + else: + result = self.run_forward_backward( + model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs + ) + + self.assert_buffer_empty() + + return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py new file mode 100644 index 000000000000..825c192d8fd5 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -0,0 +1,769 @@ +from copy import deepcopy +from typing import Tuple + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode +from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +class MlpModel(nn.Module): + def __init__(self, in_dim, out_dim, num_layers): + super().__init__() + self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +# 1) Test manual v_schedule with multiple microbatch +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_iter_input(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) + + # schedule list + zbv_schedule = [ + # stage 0 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + 
# chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), + ], + # stage 1 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=0), + 
ScheduledNode(type="W", chunk=1, stage=1, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), + ], + # stage 2 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=0), + 
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3), + # chunk 0 bwd + 
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3), + ], + # stage 3 + [ + # microbatch 0 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=0), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3), + 
ScheduledNode(type="F", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3), + ], + ] + + scheduler = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? + stage_manager=stage_manager, + num_model_chunks=pp_size, + num_microbatch=num_microbatch, + overlap_p2p=False, + ) + + # loss func + def criterion(x, *args, **kwargs): + return (x * x).mean() + + # init model and input + batch_size = 4 + num_layers = 8 + in_dim = out_dim = 8 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] + + input_base = [t.clone() for t in data_iter] + model_base = deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + local_chunk.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + local_chunk.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + local_chunk.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + local_chunk.append(sub_model) + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) + + print( + f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=local_chunk, + data_iter=iter(data_iter), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base[0]) + loss_base = criterion(output_base) + loss_base.backward() + optimizer_base.step() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # assert weight + ########################## + if rank == 0: + # layer 0 + assert_close(local_chunk[0].weight, model_base.layers[0].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(local_chunk[1].weight, model_base.layers[7].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # 
layer 1 + assert_close(local_chunk[0].weight, model_base.layers[1].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(local_chunk[1].weight, model_base.layers[6].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) + if rank == 2: + # layer 2 + assert_close(local_chunk[0].weight, model_base.layers[2].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(local_chunk[1].weight, model_base.layers[5].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) + if rank == 3: + # layer 3 + assert_close(local_chunk[0].weight, model_base.layers[3].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) + # layer 4 + assert_close(local_chunk[1].weight, model_base.layers[4].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) + + +# 2) add optimizer base 1) +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 8, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_vschedule_with_optim(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) + + h, a, s = 4096, 32, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatch, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + # max_mem=mem_f * (p * 2 + m_offset), + ) + + zbv_schedule = graph.get_v_schedule() + + scheduler = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? 
+ stage_manager=stage_manager, + num_model_chunks=num_model_chunk, + num_microbatch=num_microbatch, + overlap_p2p=False, + ) + + # init loss func + def criterion(x, *args, **kwargs): + return (x * x).mean() + + # init model and input + batch_size = test_config["batch_size"] + num_layers = 8 + assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" + in_dim = out_dim = 4096 + before_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] + + input_base = [t.clone() for t in data_iter] + model_base = deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + local_chunk.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + local_chunk.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + local_chunk.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + local_chunk.append(sub_model) + + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5)) + + after_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=local_chunk, + data_iter=iter(data_iter), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() + + after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 + + # assert memory + if rank != 0: + # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + # output hid_dim * hid_dim * 4(fp32) / 1024**3 + # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}") + assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3) + else: + # rank0 will also hold output; + print( + f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}" + ) + assert round((after_pp_step_memory - after_init_memory), 5) <= round( + (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 + ) + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base[0]) + loss_base = criterion(output_base) + loss_base.backward() + optimizer_base.step() + + ########################## + # assert loss & output + ########################## + # only chunk 1 stage 0 hold loss and output + if rank == 
0: + assert_close(result["loss"], loss_base) + assert_close(result["outputs"], output_base) + + # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ") + ########################## + # assert weight + ########################## + if rank == 0: + # layer 0 + assert_close(local_chunk[0].weight, model_base.layers[0].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(local_chunk[1].weight, model_base.layers[7].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # layer 1 + assert_close(local_chunk[0].weight, model_base.layers[1].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(local_chunk[1].weight, model_base.layers[6].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) + if rank == 2: + # layer 2 + assert_close(local_chunk[0].weight, model_base.layers[2].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(local_chunk[1].weight, model_base.layers[5].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) + if rank == 3: + # layer 3 + assert_close(local_chunk[0].weight, model_base.layers[3].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) + # layer 4 + assert_close(local_chunk[1].weight, model_base.layers[4].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) + + ########################## + # assert optim state + ########################## + optim_base_state = optimizer_base.state_dict()["state"] + optim_pp_state = optimizer_pp.state_dict()["state"] + optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0] + optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0] + # if rank == 0: + # print(f"optim_base_state {optim_base_state}") + + # assert param group + for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): + if key_base == key_pp: + if key_base != "params": + assert val_base == val_pp + else: + # BUG: + # param_base: [0, 1, 2, 3, 4, 5, 6, 7]; + # params pp: [0, 1]; + assert val_base[:2] == val_pp + + # assert state + assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"]) + assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"]) + + +# TODO:4) support Hybrid base 3) +def run_with_hybridplugin(test_config): + pass + + +# TODO:5) support MoEHybrid base 3) +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_with_moehybridplugin(test_config): + model_zoo.get_sub_registry("transformers_bert") + test_config["use_lazy_init"] = False + test_config["initial_scale"] = 2**16 + model_list = [ + "transformers_bert", + ] + + +# TODO:6) support booster & Hybrid base 4) + +# TODO:7) support booster & MoEHybrid base 4) + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + # run_fwd_bwd_iter_input() + run_fwd_bwd_vschedule_with_optim() + # run_with_moehybridplugin() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn( + run_dist, + nprocs=4, + ) + + 
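run_fwd_bwd_vschedule_with_optim above derives its node lists from PipelineGraph instead of spelling them out by hand, reusing the test's heuristic cost and memory constants (illustrative, not tuned). Condensed:

```python
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

h, a, s = 4096, 32, 1024
mem_f = 34 * h + 5 * a * s  # rough forward activation footprint
mem_w = -32 * h             # released when the weight grads are materialized
mem_b = -mem_w - mem_f      # the rest is released by the backward-b step

graph = PipelineGraph(
    n_stage=4, n_micro=4,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,
    f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()  # one ScheduledNode list per pipeline stage
```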
+if __name__ == "__main__": + test_pp() From 3dd5d59b93f7fca988f049365bfb5e1e86cb0cb3 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 03:48:50 +0000 Subject: [PATCH 22/47] hybrid support zbv --- .../naive_amp/mixed_precision_mixin/base.py | 2 +- .../naive_amp/mixed_precision_optimizer.py | 4 +- .../booster/plugin/hybrid_parallel_plugin.py | 46 +++++--- .../pipeline/schedule/zero_bubble_pp.py | 105 +++++++++++++----- colossalai/pipeline/stage_manager.py | 17 ++- colossalai/shardformer/modeling/llama.py | 1 + colossalai/shardformer/policies/llama.py | 12 +- colossalai/zero/gemini/gemini_ddp.py | 2 +- colossalai/zero/gemini/gemini_optimizer.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 8 +- tests/test_shardformer/test_model/_utils.py | 13 ++- .../test_model/test_shard_llama.py | 53 +++++++-- 12 files changed, 196 insertions(+), 71 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index fc7e0b74179a..b2ba47f6762d 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -43,7 +43,7 @@ def zero_grad(self): dtype: torch.dtype @abstractmethod - def pre_backward(self, loss: Tensor) -> Tensor: + def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor: """Called before backward. Args: diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 9e07bdebf8fa..700f80336cf0 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -89,9 +89,9 @@ def backward(self, loss: Tensor, *args, **kwargs): loss = self.mixed_precision.pre_backward(loss) loss.backward(*args, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) - tensor.backward(grad) + tensor.backward(grad, inputs=inputs, retain_graph=retain_graph) def zero_grad(self, *args, **kwargs): for p in self.working_to_master_map.keys(): diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1b3b765c2ff0..1323d14b320c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -28,7 +28,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer @@ -39,6 +39,7 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero.low_level import LowLevelZeroOptimizer from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from .pp_plugin_base import PipelinePluginBase @@ -315,7 +316,7 @@ def backward(self, loss: 
Tensor, *args, **kwargs): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -332,7 +333,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -538,7 +539,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -554,7 +555,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): None """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -768,7 +769,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: else: return - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): """ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -784,7 +785,7 @@ def backward(self, loss, retain_graph=False): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, retain_graph) + super().backward(loss, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -793,7 +794,7 @@ def backward(self, loss, retain_graph=False): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -809,7 +810,7 @@ def backward_by_grad(self, tensor, grad): None """ # Call the superclass backward_by_grad method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. 
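The hunk below wires the new "zbv" style into HybridParallelPlugin and asserts its preconditions: exactly two model chunks per rank and an explicit scheduler_nodes list. A hedged construction sketch built only from those constraints; the keyword values are placeholders and `zbv_schedule` is assumed to come from PipelineGraph.get_v_schedule():

```python
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=4,
    pp_style="zbv",
    num_model_chunks=2,            # zbv requires exactly two chunks per rank
    num_microbatches=4,
    scheduler_nodes=zbv_schedule,  # assumed precomputed; must not be None for zbv
    zero_stage=1,
    precision="bf16",
)
```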
@@ -1013,6 +1014,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + scheduler_nodes: List = None, num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, @@ -1029,6 +1031,7 @@ def __init__( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert not pp_style == "zbv" or scheduler_nodes is not None, f"scheduler_nodes must not be None when using zero bubble pipeline." if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" @@ -1088,12 +1091,13 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert pp_style in ["interleaved", "zbv"] or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2, "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -1103,14 +1107,15 @@ def __init__( self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=(pp_style == "interleaved"), + enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), + use_zbv=(pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -1119,12 +1124,20 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, ) + elif pp_style == "zbv": + self.scheduler = ZeroBubbleVPipeScheduler( + stage_manager=self.stage_manager, + schedule = scheduler_nodes, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + ) else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": @@ -1236,7 +1249,6 @@ def configure( # Replace with distributed implementation if exists optimizer = cast_to_distributed(optimizer) - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: self.logger.warning( "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. 
Disabling ZeRO.", @@ -1352,7 +1364,7 @@ def execute_pipeline( ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx, model._wait_all_gather(): - outputs = self.schedule.forward_backward_step( + outputs = self.scheduler.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index c1c4f13c68c2..03196c48c311 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -12,7 +12,7 @@ from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device, model_forward from .base import PipelineSchedule AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} @@ -33,10 +33,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): """ if (out is None) or (not deallocate_pipeline_outputs): return + print(f"{out=}") assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." - # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) - out.data.untyped_storage().resize_(0) + out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) + # out.data.untyped_storage().resize_(0) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -59,10 +60,9 @@ def __init__( self.batch_size: int self.last_batch_size: Optional[int] = None self.microbatch_offset: List[int] - - self.schedules = schedule # TODO: optim post valid self.do_post_validation = False + self.schedules = schedule # P2PMeta cache # self.enable_metadata_cache = enable_metadata_cache @@ -432,13 +432,17 @@ def forward_step( # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. # Only attention_mask from micro_batch is used + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) with self.stage_manager.switch_model_chunk_id(model_chunk_id): # fwd calculate - output_obj = model_chunk[model_chunk_id](input_obj) + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + # output_obj = model_chunk[model_chunk_id](input_obj) # last layer in model - if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - loss = criterion(output_obj) / self.num_microbatch + if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: @@ -472,19 +476,50 @@ def backward_b_step( # calculate bwd b step ; only dx = w*dy; # Retain the grad on the input_obj. 
- tree_map(retain_grad, input_obj) + if input_obj is None: + return None + else: + tree_map(retain_grad, input_obj) if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - optimizer.backward_by_grad( - tensor=output_obj, - grad=output_obj_grad, - inputs=input_obj, - retain_graph=True, - ) - return input_obj.grad + print(f"{input_obj=}") + print(f"{output_obj=}") + + if "hidden_states" in input_obj.keys(): + input_obj_ = input_obj["hidden_states"] + else: + input_obj_ = input_obj["input_ids"] + + if output_obj_grad is None: + optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) + # optimizer.backward_by_grad( + # tensor=output_obj, + # grad=output_obj_grad, + # inputs=input_obj, + # retain_graph=True, + # ) + else: + output_obj_ = output_obj["hidden_states"] + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad, + inputs=input_obj_, + retain_graph=True, + ) + + # if "backward_tensor_keys" not in output_obj: + # for k, grad in output_obj_grad.items(): + # optimizer.backward_by_grad(output_obj[k], grad, inputs=input_obj_, retain_graph=True) + # else: + # for k, grad in output_obj_grad.items(): + # output_obj[k].grad = grad + # for k in output_obj["backward_tensor_keys"]: + # tensor_to_backward = output_obj[k] + # optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad, inputs=input_obj_, retain_graph=True) + return input_obj_.grad def backward_w_step( self, @@ -511,12 +546,20 @@ def backward_w_step( if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss output_obj_grad = None - optimizer.backward_by_grad( - tensor=output_obj, - grad=output_obj_grad, - inputs=list(model_chunk[model_chunk_id].parameters()), - retain_graph=False, - ) + + if output_obj_grad is None: + print(optimizer) + # optimizer.backward(output_obj, inputs=list(model_chunk[model_chunk_id].parameters()), retain_graph=True) + optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=True) + else: + output_obj_ = output_obj["hidden_states"] + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad, + # inputs=list(model_chunk[model_chunk_id].parameters()), + inputs=list(model_chunk.parameters()), + retain_graph=False, + ) def schedule_f( self, @@ -540,12 +583,11 @@ def schedule_f( Returns: Nothing. 
""" - micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # Step1: recv fwd if model_chunk_id == 0: # is first stage; get input from func param if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = micro_batch + input_obj = None else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) else: @@ -557,7 +599,9 @@ def schedule_f( input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) # Here, let input_obj.requires_grad_() - tree_map(torch.Tensor.requires_grad_, input_obj) + # add input and output object for backward b + if input_obj is not None: + tree_map(torch.Tensor.requires_grad_, input_obj) # Step2: fwd step output_obj = self.forward_step( @@ -572,7 +616,8 @@ def schedule_f( # We should not detach bwd LOSS pass else: - detached_output_obj = output_obj.clone().detach() + # detached_output_obj = output_obj.clone().detach() + detached_output_obj = tree_map(detach, output_obj) # Step3: send fwd # add output to send_fwd_buffer @@ -589,10 +634,10 @@ def schedule_f( else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) - # add input and output object for backward b - self.input_tensors[model_chunk_id].append(input_obj) # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) + # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) + # deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) + self.input_tensors[model_chunk_id].append(input_obj) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) @@ -718,7 +763,7 @@ def run_forward_only( accum_loss = None # reset accum loss at fwd end; - if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): + if return_loss and self.stage_manager.is_last_stage(): accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 354f110f0b0d..47bc73db69b8 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -26,6 +26,7 @@ def __init__( pg_mesh: ProcessGroupMesh, pipeline_axis: int, enable_interleave: bool = False, + use_zbv: bool = False, num_model_chunks: int = 1, num_layers_per_stage: Optional[List[int]] = None, ) -> None: @@ -49,6 +50,7 @@ def __init__( next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") self.is_interleave = enable_interleave + self.use_zbv = use_zbv # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers self.num_model_chunks: int = num_model_chunks # for shardformer, hold stage indices of model @@ -85,6 +87,15 @@ def get_stage_index( num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) stage_indices = [] + if self.use_zbv: + stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]]) + stage_indices.append( + [ + num_layers_per_stage_accumulated[2 * num_stages - stage - 1], + num_layers_per_stage_accumulated[2 * num_stages - stage], + ] + ) + return stage_indices for model_chunk in range(num_model_chunks): start_idx = 
num_layers_per_stage_accumulated[stage + model_chunk * num_stages] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] @@ -124,7 +135,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool: if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + # use zero bubble pipeline + if self.use_zbv: + return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1 + else: + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index af610500a8eb..09876ec56f8b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -225,6 +225,7 @@ def llama_model_forward( all_self_attns += (layer_outputs[1],) if stage_manager.is_last_stage(): + print(f"{hidden_states=}") hidden_states = self.norm(hidden_states) if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 67e6e92d1d36..074734f761cf 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.norm) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.norm) else: @@ -351,8 +353,10 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) + elif stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: @@ -404,7 +408,9 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.score) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.score) return held_layers diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 80b2c7961e29..d2754cbd965b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor): loss.backward() self._post_backward() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False): raise RuntimeError("Gemini is not compatible with pipeline. 
backward_by_grad shouldn't be called in Gemini.") @staticmethod diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index fdf2a497626f..2283afd1eb42 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -298,12 +298,12 @@ def backward(self, loss: torch.Tensor): loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) - def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False): # This function is called except the last stage of pipeline parallel # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - grad = self.mix_precision_mixin.pre_backward_by_grad(grad) + grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 51d7d1eaaa33..2587e26944f8 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -408,7 +408,7 @@ def _add_to_bucket(self, param, group_id): # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" @@ -416,7 +416,7 @@ def backward(self, loss, retain_graph=False): if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) - loss.backward(retain_graph=retain_graph) + loss.backward(inputs=inputs, retain_graph=retain_graph) if not self.require_grad_sync: return @@ -427,14 +427,14 @@ def backward(self, loss, retain_graph=False): if self._overlap_communication: get_accelerator().synchronize() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) - torch.autograd.backward(tensor, grad) + torch.autograd.backward(tensor, grad, inputs=inputs, retain_graph=retain_graph,) if not self.require_grad_sync: return diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 9ad84341ac9e..2c7a3866f784 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin( sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) criterion = loss_fn - plugin = pluggin_cls(**test_config) booster = Booster(plugin=plugin) @@ -311,8 +310,16 @@ def check_output_hidden_state( ): org_hidden_state = org_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): - sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + if stage_manager: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + sharded_hidden_state =
sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state + elif stage_manager.is_last_stage(ignore_chunk=True): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state else: sharded_hidden_state = sharded_output.last_hidden_state diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3c66f609787a..e4b28f7f9929 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -22,6 +22,7 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) +from colossalai.pipeline.schedule.v_schedule import PipelineGraph os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" @@ -31,6 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config ) + print(f"{sharded_model=}") if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() sharded_model.unwrap().gradient_checkpointing_enable() @@ -112,12 +114,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): + check_flag = False + if stage_manager is None: + check_flag = True + else: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + check_flag = True + elif stage_manager.is_last_stage(ignore_chunk=True): + check_flag = True + if check_flag: if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state( org_output, @@ -282,10 +292,39 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "zbv", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 0, + "initial_scale": 1, + "enable_gradient_checkpointing": False, + "parallel_output": False, + }, ], ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + if test_config.get("pp_style", None) == "zbv": + mem_f = 34 * 32 + 5 * 4 * 16 + mem_w = - 32 * 32 + mem_b = - mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage = test_config["pp_size"], + n_micro=test_config["num_microbatches"], + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue @@ -372,11 +411,11 @@ def test_llama(): spawn(check_llama, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama_3d(): - spawn(check_llama_3d, 8) +# @pytest.mark.largedist +# @rerun_if_address_is_in_use() +# @clear_cache_before_run() +# def test_llama_3d(): +# spawn(check_llama_3d, 8) if __name__ == "__main__": From fee18d03d185f4cd43d9e3402630f2f3ee682420 Mon Sep 17 00:00:00 2001 From: flybird11111 
<1829166702@qq.com> Date: Thu, 12 Sep 2024 03:55:12 +0000 Subject: [PATCH 23/47] fix fix --- .../pipeline/schedule/zero_bubble_pp.py | 33 ++----------------- colossalai/shardformer/modeling/llama.py | 1 - .../test_model/test_shard_llama.py | 10 +++--- 3 files changed, 8 insertions(+), 36 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 03196c48c311..9cdb3f5d1b1f 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -33,11 +33,9 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): """ if (out is None) or (not deallocate_pipeline_outputs): return - print(f"{out=}") assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) - # out.data.untyped_storage().resize_(0) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -439,7 +437,6 @@ def forward_step( internal_inputs = {} if input_obj is None else input_obj internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - # output_obj = model_chunk[model_chunk_id](input_obj) # last layer in model if self.stage_manager.is_last_stage(): loss = criterion(output_obj, micro_batch) / self.num_microbatch @@ -484,9 +481,6 @@ def backward_b_step( if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - - print(f"{input_obj=}") - print(f"{output_obj=}") if "hidden_states" in input_obj.keys(): input_obj_ = input_obj["hidden_states"] @@ -495,12 +489,6 @@ def backward_b_step( if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) - # optimizer.backward_by_grad( - # tensor=output_obj, - # grad=output_obj_grad, - # inputs=input_obj, - # retain_graph=True, - # ) else: output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( @@ -509,16 +497,6 @@ def backward_b_step( inputs=input_obj_, retain_graph=True, ) - - # if "backward_tensor_keys" not in output_obj: - # for k, grad in output_obj_grad.items(): - # optimizer.backward_by_grad(output_obj[k], grad, inputs=input_obj_, retain_graph=True) - # else: - # for k, grad in output_obj_grad.items(): - # output_obj[k].grad = grad - # for k in output_obj["backward_tensor_keys"]: - # tensor_to_backward = output_obj[k] - # optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad, inputs=input_obj_, retain_graph=True) return input_obj_.grad def backward_w_step( @@ -548,15 +526,12 @@ def backward_w_step( output_obj_grad = None if output_obj_grad is None: - print(optimizer) - # optimizer.backward(output_obj, inputs=list(model_chunk[model_chunk_id].parameters()), retain_graph=True) optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=True) else: output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad, - # inputs=list(model_chunk[model_chunk_id].parameters()), inputs=list(model_chunk.parameters()), retain_graph=False, ) @@ -616,7 +591,6 @@ def schedule_f( # We should not detach bwd LOSS pass else: - # detached_output_obj = output_obj.clone().detach() detached_output_obj = tree_map(detach, output_obj) # Step3: send fwd @@ -634,10 
+608,10 @@ def schedule_f( else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + # add input and output object for backward b + self.input_tensors[model_chunk_id].append(input_obj) # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) - # deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) - self.input_tensors[model_chunk_id].append(input_obj) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) @@ -690,7 +664,6 @@ def schedule_b( # we save output_tensor_grad here self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) - # _wait_p2p(recv_bwd_handles) # Step2: bwd step input_object_grad = self.backward_b_step( model_chunk=model_chunk, @@ -763,7 +736,7 @@ def run_forward_only( accum_loss = None # reset accum loss at fwd end; - if return_loss and self.stage_manager.is_last_stage(): + if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 09876ec56f8b..af610500a8eb 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -225,7 +225,6 @@ def llama_model_forward( all_self_attns += (layer_outputs[1],) if stage_manager.is_last_stage(): - print(f"{hidden_states=}") hidden_states = self.norm(hidden_states) if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode): hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index e4b28f7f9929..6fe7e0d7a920 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -411,11 +411,11 @@ def test_llama(): spawn(check_llama, 4) -# @pytest.mark.largedist -# @rerun_if_address_is_in_use() -# @clear_cache_before_run() -# def test_llama_3d(): -# spawn(check_llama_3d, 8) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) if __name__ == "__main__": From b93d00820ea1d27a363ee70aaf0b729b821d561b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 10:55:44 +0000 Subject: [PATCH 24/47] fix --- .../naive_amp/mixed_precision_optimizer.py | 4 ++-- .../booster/mixed_precision/fp16_torch.py | 4 ++-- .../booster/plugin/hybrid_parallel_plugin.py | 10 ++++++---- colossalai/interface/optimizer.py | 4 ++-- .../pipeline/schedule/zero_bubble_pp.py | 19 ++++++------------- .../test_model/test_shard_llama.py | 5 ++--- 6 files changed, 20 insertions(+), 26 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 700f80336cf0..121c92011cc5 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -85,9 +85,9 @@ def __init__( master_params.append(master_p) group["params"] = master_params - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, 
inputs=None, retain_graph=False, **kwargs): loss = self.mixed_precision.pre_backward(loss) - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index c757a878d97a..a85d9f808546 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -46,9 +46,9 @@ def __init__( growth_interval=growth_interval, ) - def backward(self, loss: Tensor, *args, **kwargs) -> None: + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None: scaled_loss = self.scale_loss(loss) - scaled_loss.backward(*args, **kwargs) + scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def step(self, *args, **kwargs) -> Optional[float]: out = self.scaler.step(self.optim, *args, **kwargs) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1323d14b320c..791281506bff 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -289,7 +289,7 @@ def __init__( self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 super().__init__(optim) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -307,7 +307,7 @@ def backward(self, loss: Tensor, *args, **kwargs): """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -513,7 +513,7 @@ def __init__( max_norm=max_norm, ) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -530,7 +530,7 @@ def backward(self, loss: Tensor, *args, **kwargs): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. 
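Because `pre_backward` in these mixed-precision wrappers scales the loss before delegating to `Tensor.backward`, forwarding `inputs=` and `retain_graph=` through them stays correct: scaling multiplies dx and dw by the same factor, so the split backward is unaffected. A short runnable sketch of that interaction (the names and the scale value are illustrative, not taken from the codebase):

    import torch

    scale = 1024.0  # stand-in for a loss scaler's current scale
    layer = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    scaled_loss = layer(x).sum() * scale

    scaled_loss.backward(inputs=[x], retain_graph=True)    # B step: scaled dx only
    scaled_loss.backward(inputs=list(layer.parameters()))  # W step: scaled dw, graph freed
    assert x.grad is not None and layer.weight.grad is not None
    # Both gradients carry the same factor `scale`, to be unscaled at optimizer step time.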
@@ -1104,6 +1104,8 @@ def __init__( assert ( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + if pp_style == "zbv": + self.logger.warning("""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""") self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index a236434a55d6..c8cf3ec21360 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs): """ self.optim.zero_grad(*args, **kwargs) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): """ Performs a backward pass on the loss. """ - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 9cdb3f5d1b1f..6c30b32dad38 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -526,7 +526,7 @@ def backward_w_step( output_obj_grad = None if output_obj_grad is None: - optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=True) + optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( @@ -587,31 +587,27 @@ def schedule_f( accum_loss=accum_loss, outputs=outputs, ) - if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # We should not detach bwd LOSS - pass - else: - detached_output_obj = tree_map(detach, output_obj) # Step3: send fwd # add output to send_fwd_buffer if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): + detached_output_obj = tree_map(detach, output_obj) self.local_send_forward_buffer.append(detached_output_obj) + # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj + # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) else: - self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + self.send_forward_buffer[model_chunk_id].append(output_obj) else: # is first stage; end of fwd; append LOSS to local_send_backward_buffer if self.stage_manager.is_first_stage(ignore_chunk=True): pass else: - self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + self.send_forward_buffer[model_chunk_id].append(output_obj) # add input and output object for backward b self.input_tensors[model_chunk_id].append(input_obj) - # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) @@ -622,9 +618,6 @@ def schedule_b( model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, - # input_obj: Optional[dict], - # output_obj: Union[dict, 
torch.Tensor], - # output_obj_grad: Optional[dict], ): """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd; diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 6fe7e0d7a920..04b962a513db 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -32,10 +32,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config ) - print(f"{sharded_model=}") if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() - sharded_model.unwrap().gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False}) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -302,7 +301,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "zero_stage": 0, "initial_scale": 1, - "enable_gradient_checkpointing": False, + "enable_gradient_checkpointing": True, "parallel_output": False, }, ], From eef5d83ba97fbe9e8d46cf4676f3db1b1c0d68f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 04:00:18 +0000 Subject: [PATCH 25/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../booster/plugin/hybrid_parallel_plugin.py | 15 +++++++---- .../pipeline/schedule/zero_bubble_pp.py | 10 ++++--- colossalai/shardformer/policies/llama.py | 4 +-- colossalai/zero/gemini/gemini_optimizer.py | 4 ++- colossalai/zero/low_level/low_level_optim.py | 7 ++++- .../test_model/test_shard_llama.py | 26 +++++++++---------- 6 files changed, 41 insertions(+), 25 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 791281506bff..3a78db6ac4fa 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -39,7 +39,6 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero.low_level import LowLevelZeroOptimizer from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle -from colossalai.pipeline.schedule.v_schedule import PipelineGraph from .pp_plugin_base import PipelinePluginBase @@ -1031,7 +1030,9 @@ def __init__( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" - assert not pp_style == "zbv" or scheduler_nodes is not None, f"scheduler_nodes must not be None when using zero bubble pipeline." + assert ( + not pp_style == "zbv" or scheduler_nodes is not None + ), f"scheduler_nodes must not be None when using zero bubble pipeline." 
if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" @@ -1096,8 +1097,12 @@ def __init__( assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" - assert pp_style in ["interleaved", "zbv"] or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" - assert pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2, "num_model_chunks must be 2 when using zero bubble pipeline" + assert ( + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" @@ -1135,7 +1140,7 @@ def __init__( elif pp_style == "zbv": self.scheduler = ZeroBubbleVPipeScheduler( stage_manager=self.stage_manager, - schedule = scheduler_nodes, + schedule=scheduler_nodes, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, microbatch_size=microbatch_size, diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 6c30b32dad38..3b6d88a3806b 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -12,7 +12,7 @@ from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device, model_forward +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} @@ -35,7 +35,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): return assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." 
- out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) + out.data = torch.empty( + (1,), + device=out.device, + dtype=out.dtype, + ) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -481,7 +485,7 @@ def backward_b_step( if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - + if "hidden_states" in input_obj.keys(): input_obj_ = input_obj["hidden_states"] else: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 074734f761cf..60da448d8767 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -262,7 +262,7 @@ def get_held_layers(self) -> List[Module]: for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): - held_layers.append(module.norm) + held_layers.append(module.norm) elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.norm) @@ -356,7 +356,7 @@ def get_held_layers(self) -> List[Module]: if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) elif stage_manager.is_last_stage(ignore_chunk=True): - held_layers.append(self.model.lm_head) + held_layers.append(self.model.lm_head) return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 2283afd1eb42..ccd4634b5fe2 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -298,7 +298,9 @@ def backward(self, loss: torch.Tensor): loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) - def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False): + def backward_by_grad( + self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False + ): # This function is called except the last stage of pipeline parallel # It receives the scaled grad from the previous rank # No need to scale the grad again diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 2587e26944f8..9cc44c7538dd 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -434,7 +434,12 @@ def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bo if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) - torch.autograd.backward(tensor, grad, inputs=inputs, retain_graph=retain_graph,) + torch.autograd.backward( + tensor, + grad, + inputs=inputs, + retain_graph=retain_graph, + ) if not self.require_grad_sync: return diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 04b962a513db..b74365c7f2af 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -7,6 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import 
Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter @@ -22,7 +23,6 @@ run_forward_backward_with_hybrid_plugin, unwrap_model, ) -from colossalai.pipeline.schedule.v_schedule import PipelineGraph os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" @@ -121,7 +121,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if stage_manager.is_first_stage(ignore_chunk=True): check_flag = True elif stage_manager.is_last_stage(ignore_chunk=True): - check_flag = True + check_flag = True if check_flag: if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 @@ -310,18 +310,18 @@ def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") if test_config.get("pp_style", None) == "zbv": mem_f = 34 * 32 + 5 * 4 * 16 - mem_w = - 32 * 32 - mem_b = - mem_w - mem_f + mem_w = -32 * 32 + mem_b = -mem_w - mem_f scheduler_nodes = PipelineGraph( - n_stage = test_config["pp_size"], - n_micro=test_config["num_microbatches"], - f_cost=1000, - b_cost=1000, - w_cost=1000, - c_cost=1, - f_mem=mem_f, - b_mem=mem_b, - w_mem=mem_w, + n_stage=test_config["pp_size"], + n_micro=test_config["num_microbatches"], + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, ).get_v_schedule() test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): From b55279c7ecf802d594d4e7ed5cdfb8aed2a2c7de Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 11:29:46 +0000 Subject: [PATCH 26/47] fix --- colossalai/pipeline/schedule/zero_bubble_pp.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 3b6d88a3806b..3700b8c95f7a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -35,11 +35,7 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): return assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." 
- out.data = torch.empty( - (1,), - device=out.device, - dtype=out.dtype, - ) + out.data.untyped_storage().resize_(0) class ZeroBubbleVPipeScheduler(PipelineSchedule): @@ -336,6 +332,7 @@ def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: next_rank = self.stage_manager.get_next_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) + tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_tensor) return send_handles else: @@ -354,6 +351,7 @@ def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: prev_rank = self.stage_manager.get_prev_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_tensor, prev_rank) + tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_tensor) return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: @@ -597,10 +595,11 @@ def schedule_f( if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): - detached_output_obj = tree_map(detach, output_obj) + output_obj_clone = tree_map(torch.Tensor.clone, output_obj) + detached_output_obj = tree_map(detach, output_obj_clone) self.local_send_forward_buffer.append(detached_output_obj) # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - # tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) + tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) else: self.send_forward_buffer[model_chunk_id].append(output_obj) else: From 3efd8d4744f897e55378783668e3657a7253009a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:58:15 +0000 Subject: [PATCH 27/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 +++- tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3a78db6ac4fa..5d114ab9c315 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1110,7 +1110,9 @@ def __init__( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" if pp_style == "zbv": - self.logger.warning("""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""") + self.logger.warning( + """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""" + ) self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index b74365c7f2af..d925687cd875 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -34,7 +34,7 @@ 
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() - sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False}) + sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster From 1839030b5c4b5e1b93b7e98a766601781e96b292 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Sep 2024 20:06:41 +0800 Subject: [PATCH 28/47] Update zero_bubble_pp.py --- colossalai/pipeline/schedule/zero_bubble_pp.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 3700b8c95f7a..759a6144c169 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -484,11 +484,7 @@ def backward_b_step( # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - if "hidden_states" in input_obj.keys(): - input_obj_ = input_obj["hidden_states"] - else: - input_obj_ = input_obj["input_ids"] - + input_obj_ = input_obj["hidden_states"] if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: From 433c8a9194f45ac07c03049d794763323510f281 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Sep 2024 02:51:28 +0000 Subject: [PATCH 29/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 24 +++++++++++-------- .../test_model/test_shard_llama.py | 2 +- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 759a6144c169..5b4092bcefe3 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -11,6 +11,8 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.zero.low_level import LowLevelZeroOptimizer +from contextlib import nullcontext from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -485,16 +487,18 @@ def backward_b_step( assert output_obj_grad is None input_obj_ = input_obj["hidden_states"] - if output_obj_grad is None: - optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) - else: - output_obj_ = output_obj["hidden_states"] - optimizer.backward_by_grad( - tensor=output_obj_, - grad=output_obj_grad, - inputs=input_obj_, - retain_graph=True, - ) + ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext() + with ctx: + if output_obj_grad is None: + optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) + else: + output_obj_ = output_obj["hidden_states"] + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad, + inputs=input_obj_, + retain_graph=True, + ) return input_obj_.grad def backward_w_step( diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 
d925687cd875..9f67ecbea687 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -299,7 +299,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 4, "enable_all_optimization": False, "precision": "fp16", - "zero_stage": 0, + "zero_stage": 1, "initial_scale": 1, "enable_gradient_checkpointing": True, "parallel_output": False, From 3fb1e429364edcbecb4f40c86dd518e72fc7c761 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Sep 2024 06:13:22 +0000 Subject: [PATCH 30/47] fix-ci --- .github/workflows/build_on_pr.yml | 2 +- .github/workflows/build_on_schedule.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 58cd8826809a..79d758c87976 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . pip install --no-cache-dir -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index fc688a71bd92..e7b5063279eb 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -55,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install --no-cache-dir -r requirements/requirements-test.txt From cd2e34bc2206565265484a807996a2e65bb6b031 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 13 Sep 2024 08:21:36 +0000 Subject: [PATCH 31/47] fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix --- colossalai/pipeline/schedule/zero_bubble_pp.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 5b4092bcefe3..7cff690d248e 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -11,9 +12,6 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.zero.low_level import LowLevelZeroOptimizer -from contextlib import nullcontext - from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -487,7 +485,13 @@ def backward_b_step( assert output_obj_grad is None input_obj_ = input_obj["hidden_states"] - ctx = optimizer.no_sync() if isinstance(optimizer, LowLevelZeroOptimizer) else nullcontext() + + # Attempt to disable gradient synchronization when using the LowLevelZeroPlugin. 
+ try: + ctx = optimizer.no_sync() + except Exception as e: + ctx = nullcontext() + with ctx: if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) From 4e0f212ed42d278ebe1b868c8b75654b554b9477 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:30:21 +0000 Subject: [PATCH 32/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 36 +- .../test_schedule/test_zerobubble_pp.py | 857 +++--------------- tests/test_shardformer/test_model/_utils.py | 2 +- .../test_model/test_shard_llama.py | 272 +++--- 4 files changed, 304 insertions(+), 863 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 7cff690d248e..097b1179693a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -157,8 +157,8 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.num_microbatch % self.stage_manager.num_stages == 0 ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" - if self.forward_only: - self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + # if self.forward_only: + # self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) # if self.batch_size != self.last_batch_size: # self.enable_metadata_cache = False @@ -435,15 +435,24 @@ def forward_step( micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) with self.stage_manager.switch_model_chunk_id(model_chunk_id): - # fwd calculate - internal_inputs = {} if input_obj is None else input_obj - internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] - output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + if isinstance(model_chunk, ModuleList): + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + else: + # fwd calculate + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + # last layer in model if self.stage_manager.is_last_stage(): + print(f"aaaaa{output_obj=}") loss = criterion(output_obj, micro_batch) / self.num_microbatch + print(f"bbbb{loss=}") if accum_loss is not None: - accum_loss.add_(loss.detach()) + print(f"accum_loss{accum_loss=}") + if not torch.isinf(loss): + accum_loss.add_(loss.detach()) + print(f"add accum_loss{accum_loss=}") if outputs is not None: outputs.append(tree_map(detach, output_obj)) return loss @@ -497,6 +506,7 @@ def backward_b_step( optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: output_obj_ = output_obj["hidden_states"] + print(f"{model_chunk_id=}, {output_obj_grad=}") optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad, @@ -531,6 +541,8 @@ def backward_w_step( # loss backward; output_obj is loss output_obj_grad = None + if isinstance(model_chunk, ModuleList): + model_chunk = model_chunk[model_chunk_id] if output_obj_grad is None: optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: @@ -742,9 +754,10 @@ def run_forward_only( outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None # while we still have schedules_node in self.schedules - for it in range(len(self.schedules)): - 
scheduled_node = self.schedules[it] - + # while we still have schedules_node in self.schedules + schedules = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) + for it in range(len(schedules)): + scheduled_node = schedules[it] if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: # communication communication_func = self.communication_map[scheduled_node.type] @@ -761,6 +774,7 @@ def run_forward_only( # return loss & output if outputs is not None: outputs = merge_batch(outputs) + print(f"{accum_loss=}") return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( @@ -857,6 +871,6 @@ def forward_backward_step( model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) - self.assert_buffer_empty() + # self.assert_buffer_empty() return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 825c192d8fd5..0bde70ec6c16 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -1,5 +1,6 @@ -from copy import deepcopy -from typing import Tuple +import copy +from functools import partial +from types import MethodType import pytest import torch @@ -10,18 +11,20 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper -from colossalai.logging import disable_existing_loggers from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + +NUM_LAYER = 8 +DIM = 4 class MlpModel(nn.Module): - def __init__(self, in_dim, out_dim, num_layers): + def __init__(self): super().__init__() - self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)]) def forward(self, x): for layer in self.layers: @@ -29,741 +32,163 @@ def forward(self, x): return x -def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: - num_params = 0 - num_params_trainable = 0 - for p in model.parameters(): - num_params += p.numel() - if p.requires_grad: - num_params_trainable += p.numel() - return num_params, num_params_trainable - - -# 1) Test manual v_schedule with multiple microbatch -@parameterize( - "test_config", - [ - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - ], -) -def run_fwd_bwd_iter_input(test_config): - # init dist - rank = dist.get_rank() - pp_size = test_config["pp_size"] - pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = test_config["num_microbatches"] - num_model_chunk = test_config["num_model_chunk"] - # stage_manager +def pp_linear_fwd( + forward, + data: torch.Tensor = None, + hidden_states: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + model_chunk_id: int = None, +): + with stage_mgr.switch_model_chunk_id(model_chunk_id): + if stage_mgr.is_first_stage(): + return {"hidden_states": forward(data)} + elif stage_mgr.is_last_stage(): + return forward(hidden_states) + else: + return {"hidden_states": 
forward(hidden_states)}
+
+
+def run_pp(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
+    """
+    This test is to examine the correctness of the zero-bubble V-schedule (ZBV) pipeline, compared with torch.
+    Be aware it contains some hardcoded values.
+    """
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
+
+    # create model
+    seed_all(1453)
+    torch_model = MlpModel().cuda()
+    pp_model = copy.deepcopy(torch_model).cuda()
+
+    pg_mesh = ProcessGroupMesh(world_size)
     stage_manager = PipelineStageManager(
-        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+        pg_mesh, pipeline_axis=0, enable_interleave=True, use_zbv=True, num_model_chunks=num_model_chunk
     )
     # schedule list
-    zbv_schedule = [
-        # stage 0
-        [
-            # microbatch 0
-            # chunk 0 fwd
-            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
-            ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
-            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0),
-            # chunk 1 fwd
-            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0),
-            ScheduledNode(type="F", chunk=1, stage=0, minibatch=0),
-            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0),
-            # chunk 1 bwd
-            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0),
-            ScheduledNode(type="B", chunk=1, stage=0, minibatch=0),
-            ScheduledNode(type="W", chunk=1, stage=0, minibatch=0),
-            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
-            # chunk 0 bwd
-            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0),
-            ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
-            ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
-            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
-            # microbatch 1
-            # chunk 0 fwd
-            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1),
-            ScheduledNode(type="F", chunk=0, stage=0, minibatch=1),
-            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1),
-            # chunk 1 fwd
-            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1),
-            ScheduledNode(type="F", chunk=1, stage=0, minibatch=1),
-            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1),
-            # chunk 1 bwd
-            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1),
-            ScheduledNode(type="B", chunk=1, stage=0, minibatch=1),
-            ScheduledNode(type="W", chunk=1, stage=0, minibatch=1),
-            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1),
-            # chunk 0 bwd
-            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1),
-            ScheduledNode(type="B", chunk=0, stage=0, minibatch=1),
-            ScheduledNode(type="W", chunk=0, stage=0, minibatch=1),
-            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
-            # microbatch 2
-            # chunk 0 fwd
-            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2),
-            ScheduledNode(type="F", chunk=0, stage=0, minibatch=2),
-            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2),
-            # chunk 1 fwd
-            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2),
-            ScheduledNode(type="F", chunk=1, stage=0, minibatch=2),
-            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2),
-            # chunk 1 bwd
-            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2),
-            ScheduledNode(type="B", chunk=1, stage=0, minibatch=2),
-            ScheduledNode(type="W", chunk=1, stage=0, minibatch=2),
-            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2),
-            # chunk 0 bwd
-            ScheduledNode(type="RECV_BACKWARD", chunk=0,
stage=0, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=0, minibatch=2), - ScheduledNode(type="W", chunk=0, stage=0, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=0, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=0, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), - ], - # stage 1 - [ - # microbatch 0 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="F", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), - # microbatch 1 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="F", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), - # microbatch 2 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="F", chunk=1, stage=1, 
minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=1, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=1, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), - ], - # stage 2 - [ - # microbatch 0 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), - # microbatch 1 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=1), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=1), - 
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1), - # microbatch 2 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=2, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=2, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3), - ], - # stage 3 - [ - # microbatch 0 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="W", chunk=1, stage=3, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0), - # microbatch 1 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, 
stage=3, minibatch=1), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="W", chunk=1, stage=3, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=1), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1), - # microbatch 2 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="W", chunk=1, stage=3, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=2), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2), - # microbatch 3 - # chunk 0 fwd - ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="F", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3), - # chunk 1 fwd - ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="F", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3), - # chunk 1 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="B", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="W", chunk=1, stage=3, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3), - # chunk 0 bwd - ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="B", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="W", chunk=0, stage=3, minibatch=3), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3), - ], - ] - - scheduler = ZeroBubbleVPipeScheduler( - schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? 
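# Every per-stage node list spelled out by hand above follows the same V
# pattern, which PipelineGraph.get_v_schedule() derives automatically from
# per-op time and memory costs; the rewritten test below switches to it.
# A small sketch of that usage with the toy costs the new test adopts
# (f_mem/b_mem/w_mem are per-microbatch activation deltas and sum to zero):
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

mem_f = 34 * 32 + 5 * 4 * 16  # forward allocates roughly 34*h + 5*a*s per layer
mem_w = -32 * 32              # the weight-grad (W) pass frees the 32*h it kept
mem_b = -mem_w - mem_f        # the input-grad (B) pass frees the rest

graph = PipelineGraph(
    n_stage=4,    # pipeline ranks
    n_micro=12,   # microbatches
    f_cost=1000,  # relative forward time
    b_cost=1000,  # input-grad (B) time
    w_cost=1000,  # weight-grad (W) time
    c_cost=1,     # communication time
    f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()  # one ScheduledNode list per stage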
- stage_manager=stage_manager, - num_model_chunks=pp_size, - num_microbatch=num_microbatch, - overlap_p2p=False, - ) - - # loss func - def criterion(x, *args, **kwargs): - return (x * x).mean() - - # init model and input - batch_size = 4 - num_layers = 8 - in_dim = out_dim = 8 - print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") - model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) - data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - - input_base = [t.clone() for t in data_iter] - model_base = deepcopy(model) - - if rank == 0: - # layer 0 & 7 to chunk 0 on rank0 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 0 or idx == 7: - local_chunk.append(sub_model) - elif rank == 1: - # layer 1 & 6 to chunk 1 on rank1 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 1 or idx == 6: - local_chunk.append(sub_model) - elif rank == 2: - # layer 2 & 5 to chunk 2 on rank2 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 2 or idx == 5: - local_chunk.append(sub_model) - else: - # layer 3 & 4 to chunk 3 on rank3 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 3 or idx == 4: - local_chunk.append(sub_model) - # init optimizer - optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) - optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) - - print( - f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" - ) - - torch.cuda.synchronize() - result = scheduler.forward_backward_step( - model_chunk=local_chunk, - data_iter=iter(data_iter), - criterion=criterion, - optimizer=optimizer_pp, - return_loss=True, - return_outputs=True, - ) - - optimizer_pp.step() - - ########################## - # Fwd bwd for base - ########################## - # fwd & bwd - output_base = model_base(input_base[0]) - loss_base = criterion(output_base) - loss_base.backward() - optimizer_base.step() - print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - ########################## - # assert weight - ########################## - if rank == 0: - # layer 0 - assert_close(local_chunk[0].weight, model_base.layers[0].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) - # layer 7 - assert_close(local_chunk[1].weight, model_base.layers[7].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) - if rank == 1: - # layer 1 - assert_close(local_chunk[0].weight, model_base.layers[1].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) - # layer 6 - assert_close(local_chunk[1].weight, model_base.layers[6].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) - if rank == 2: - # layer 2 - assert_close(local_chunk[0].weight, model_base.layers[2].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) - # layer 5 - assert_close(local_chunk[1].weight, model_base.layers[5].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) - if rank == 3: - # layer 3 - assert_close(local_chunk[0].weight, model_base.layers[3].weight) - assert_close(local_chunk[0].weight.grad, 
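# The rank -> layers mapping hard-coded above generalizes to the V rule:
# rank r owns layer r as chunk 0 (the downward arm) and layer
# num_layers - 1 - r as chunk 1 (the upward arm), so rank 0 holds both the
# first and the last layer. A tiny sketch of that rule, assuming 8 layers
# on 4 ranks as in these tests:
def v_layers_for_rank(rank, num_layers=8):
    # chunk 0 descends through ranks 0..3; chunk 1 ascends back 3..0
    return [rank, num_layers - 1 - rank]

assert v_layers_for_rank(0) == [0, 7]  # matches "layer 0 & 7 ... on rank0"
assert v_layers_for_rank(3) == [3, 4]  # matches "layer 3 & 4 ... on rank3"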
model_base.layers[3].weight.grad) - # layer 4 - assert_close(local_chunk[1].weight, model_base.layers[4].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) - - -# 2) add optimizer base 1) -@parameterize( - "test_config", - [ - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 8, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - ], -) -def run_fwd_bwd_vschedule_with_optim(test_config): - # init dist - rank = dist.get_rank() - pp_size = test_config["pp_size"] - pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = test_config["num_microbatches"] - num_model_chunk = test_config["num_model_chunk"] - # stage_manager - stage_manager = PipelineStageManager( - pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk - ) - - h, a, s = 4096, 32, 1024 - mem_f = 34 * h + 5 * a * s - mem_w = -32 * h + mem_f = 34 * 32 + 5 * 4 * 16 + mem_w = -32 * 32 mem_b = -mem_w - mem_f - graph = PipelineGraph( - n_stage=pp_size, - n_micro=num_microbatch, - f_cost=1, - b_cost=1, - w_cost=1, + scheduler_nodes = PipelineGraph( + n_stage=4, + n_micro=12, + f_cost=1000, + b_cost=1000, + w_cost=1000, c_cost=1, f_mem=mem_f, b_mem=mem_b, w_mem=mem_w, - # max_mem=mem_f * (p * 2 + m_offset), - ) - - zbv_schedule = graph.get_v_schedule() - - scheduler = ZeroBubbleVPipeScheduler( - schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? + ).get_v_schedule() + schedule = ZeroBubbleVPipeScheduler( stage_manager=stage_manager, + schedule=scheduler_nodes, num_model_chunks=num_model_chunk, num_microbatch=num_microbatch, - overlap_p2p=False, ) - # init loss func + sharded_model = torch.nn.ModuleList() + for idx, sub_model in enumerate(pp_model.layers): + if idx == rank or (NUM_LAYER-idx-1) == rank: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)), + sub_model._forward, + ) + sharded_model.append(sub_model.cuda()) + assert len(sharded_model) == num_model_chunk, f"{len(sharded_model)}, {num_model_chunk}, num_model_chunk is not correct" + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5)) + + # create data + seed_all(115) + input_list = [torch.rand(batch_size, DIM).cuda()] + dist.all_reduce(input_list[0]) + def criterion(x, *args, **kwargs): return (x * x).mean() - # init model and input - batch_size = test_config["batch_size"] - num_layers = 8 - assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" - in_dim = out_dim = 4096 - before_init_memory = torch.cuda.memory_allocated() / 1024**3 - print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") - model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) - data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - - input_base = [t.clone() for t in data_iter] - model_base = deepcopy(model) - - if rank == 0: - # layer 0 & 7 to chunk 0 on rank0 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 0 or idx == 7: - local_chunk.append(sub_model) - elif rank == 1: - # layer 1 & 6 to 
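# The MethodType/partial rebinding above lets a plain nn.Linear speak the
# pipeline protocol without subclassing: the original bound forward is stashed
# on the module, and the replacement receives it as its first argument. A
# minimal sketch of the same trick in isolation (names are illustrative):
from functools import partial
from types import MethodType

import torch
import torch.nn as nn

def scaled_fwd(original_fwd, x, scale=1.0):
    # original_fwd is the bound forward saved before patching
    return original_fwd(x) * scale

layer = nn.Linear(4, 4)
layer._forward = layer.forward  # keep the original bound method
layer.forward = MethodType(partial(scaled_fwd, scale=2.0), layer._forward)
out = layer.forward(torch.rand(2, 4))  # runs scaled_fwd(layer._forward, x)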
chunk 1 on rank1 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 1 or idx == 6: - local_chunk.append(sub_model) - elif rank == 2: - # layer 2 & 5 to chunk 2 on rank2 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 2 or idx == 5: - local_chunk.append(sub_model) - else: - # layer 3 & 4 to chunk 3 on rank3 - local_chunk = torch.nn.ModuleList().to(rank) - for idx, sub_model in enumerate(model.layers): - if idx == 3 or idx == 4: - local_chunk.append(sub_model) - - # init optimizer - optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) - optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5)) - - after_init_memory = torch.cuda.memory_allocated() / 1024**3 - print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") - - torch.cuda.synchronize() - result = scheduler.forward_backward_step( - model_chunk=local_chunk, - data_iter=iter(data_iter), - criterion=criterion, - optimizer=optimizer_pp, - return_loss=True, - return_outputs=True, - ) - - optimizer_pp.step() - - after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 - - # assert memory - if rank != 0: - # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 - # output hid_dim * hid_dim * 4(fp32) / 1024**3 - # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 - print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}") - assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3) - else: - # rank0 will also hold output; - print( - f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}" + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) + + # check loss + if stage_manager.is_first_stage(ignore_chunk=True): + assert_close(torch_loss, pp_ret["loss"]) + + # check gradients + for i in range(num_model_chunk): + # idx = world_size * i + rank + if i == 0: + idx = rank + else: + idx = world_size * 2 - rank - 1 + print(f"{i=}, {idx=}, {rank=}, {torch_model.layers[idx].weight.grad=}, {sharded_model[i].weight.grad=}") + assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) + + # step + torch_optimizer.step() + pp_optimizer.step() + pp_optimizer.zero_grad() + + # check updated param + for i in range(num_model_chunk): + # idx = world_size * i + rank + if i == 0: + idx = rank + else: + idx = world_size * 2 - rank - 1 + assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) + assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) + + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) - assert round((after_pp_step_memory - after_init_memory), 5) <= round( - (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 - ) - - ########################## - 
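# One V-schedule quirk the checks above rely on: unlike 1F1B, where the loss
# sits on the last rank, the V path ends on chunk 1 of stage 0, so the loss
# and outputs come home to rank 0; hence the
# stage_manager.is_first_stage(ignore_chunk=True) guards. A sketch of the
# forward path, assuming 4 stages and 2 chunks:
def v_forward_path(num_stages):
    down = [(stage, 0) for stage in range(num_stages)]          # chunk 0 arm
    up = [(stage, 1) for stage in reversed(range(num_stages))]  # chunk 1 arm
    return down + up

assert v_forward_path(4)[0] == (0, 0)   # first layer runs on rank 0
assert v_forward_path(4)[-1] == (0, 1)  # last layer (and loss) also on rank 0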
# Fwd bwd for base - ########################## - # fwd & bwd - output_base = model_base(input_base[0]) - loss_base = criterion(output_base) - loss_base.backward() - optimizer_base.step() - - ########################## - # assert loss & output - ########################## - # only chunk 1 stage 0 hold loss and output - if rank == 0: - assert_close(result["loss"], loss_base) - assert_close(result["outputs"], output_base) + if stage_manager.is_first_stage(ignore_chunk=True): + print(f"{torch_loss=}, {pp_ret['loss']}") + assert_close(torch_loss, pp_ret["loss"]) - # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ") - ########################## - # assert weight - ########################## - if rank == 0: - # layer 0 - assert_close(local_chunk[0].weight, model_base.layers[0].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) - # layer 7 - assert_close(local_chunk[1].weight, model_base.layers[7].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) - if rank == 1: - # layer 1 - assert_close(local_chunk[0].weight, model_base.layers[1].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) - # layer 6 - assert_close(local_chunk[1].weight, model_base.layers[6].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) - if rank == 2: - # layer 2 - assert_close(local_chunk[0].weight, model_base.layers[2].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) - # layer 5 - assert_close(local_chunk[1].weight, model_base.layers[5].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) - if rank == 3: - # layer 3 - assert_close(local_chunk[0].weight, model_base.layers[3].weight) - assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad) - # layer 4 - assert_close(local_chunk[1].weight, model_base.layers[4].weight) - assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) - - ########################## - # assert optim state - ########################## - optim_base_state = optimizer_base.state_dict()["state"] - optim_pp_state = optimizer_pp.state_dict()["state"] - optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0] - optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0] - # if rank == 0: - # print(f"optim_base_state {optim_base_state}") - - # assert param group - for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): - if key_base == key_pp: - if key_base != "params": - assert val_base == val_pp + for layer in sharded_model: + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None else: - # BUG: - # param_base: [0, 1, 2, 3, 4, 5, 6, 7]; - # params pp: [0, 1]; - assert val_base[:2] == val_pp - - # assert state - assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"]) - assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"]) - - -# TODO:4) support Hybrid base 3) -def run_with_hybridplugin(test_config): - pass - - -# TODO:5) support MoEHybrid base 3) -@parameterize( - "test_config", - [ - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 4, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, - ], -) -def run_with_moehybridplugin(test_config): - model_zoo.get_sub_registry("transformers_bert") - 
test_config["use_lazy_init"] = False - test_config["initial_scale"] = 2**16 - model_list = [ - "transformers_bert", - ] - - -# TODO:6) support booster & Hybrid base 4) - -# TODO:7) support booster & MoEHybrid base 4) - - -def run_dist(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # run_fwd_bwd_iter_input() - run_fwd_bwd_vschedule_with_optim() - # run_with_moehybridplugin() + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) @pytest.mark.dist +@pytest.mark.parametrize("num_microbatch", [12]) +@pytest.mark.parametrize("batch_size", [24]) +@pytest.mark.parametrize("num_model_chunk", [2]) @rerun_if_address_is_in_use() -def test_pp(): +def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): + assert NUM_LAYER % num_model_chunk == 0 spawn( - run_dist, - nprocs=4, + run_pp, + nprocs=NUM_LAYER // num_model_chunk, + num_microbatch=num_microbatch, + batch_size=batch_size, + num_model_chunk=num_model_chunk, ) if __name__ == "__main__": - test_pp() + test_pp(num_microbatch=4, batch_size=12, num_model_chunk=2) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 2c7a3866f784..dbefe872cbc8 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -397,7 +397,7 @@ def get_grad_tensors_for_check( pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - + print(f"grad_to_check {shard_grad=}") grad_to_check[suffix] = { "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 9f67ecbea687..81892817cec4 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -136,6 +136,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, rtol=rtol, shard_config=booster.plugin.shard_config, ) + print(f"check_loss") check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): @@ -143,6 +144,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 + print(f"check_weight") check_weight( llama_model, shard_llama_model, @@ -162,135 +164,135 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # Double Ring Attention - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 4, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - "inner_ring_size": 2, - }, - # Ring Attention + PP - { - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - # Ring Attention + TP - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - 
"initial_scale": 1, - }, - { # Ulysess + TP - "tp_size": 2, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { # Ulysess + PP - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "use_lazy_init": False, - "precision": "fp32", - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, + # # Double Ring Attention + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 4, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # "inner_ring_size": 2, + # }, + # # Ring Attention + PP + # { + # "tp_size": 1, + # "pp_size": 2, + # "sp_size": 2, + # "num_microbatches": 2, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # # Ring Attention + TP + # { + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { # Ulysess + TP + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { # Ulysess + PP + # "tp_size": 1, + # "pp_size": 2, + # "sp_size": 2, + # "num_microbatches": 2, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": 
"all_to_all", + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "split_gather", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 4, + # "use_lazy_init": False, + # "precision": "fp32", + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, { "tp_size": 2, "pp_size": 2, @@ -301,7 +303,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "zero_stage": 1, "initial_scale": 1, - "enable_gradient_checkpointing": True, + "enable_gradient_checkpointing": False, "parallel_output": False, }, ], @@ -410,11 +412,11 @@ def test_llama(): spawn(check_llama, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama_3d(): - spawn(check_llama_3d, 8) +# @pytest.mark.largedist +# @rerun_if_address_is_in_use() +# @clear_cache_before_run() +# def test_llama_3d(): +# spawn(check_llama_3d, 8) if __name__ == "__main__": From fa358b26e1a4d7d9339f670e016fe3cd4c4cda1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 08:26:56 +0000 Subject: [PATCH 33/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/pipeline/schedule/zero_bubble_pp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 097b1179693a..c76656200a2f 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -12,6 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager + from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -498,7 +499,7 @@ def backward_b_step( # Attempt to disable gradient synchronization when using the 
LowLevelZeroPlugin. try: ctx = optimizer.no_sync() - except Exception as e: + except Exception: ctx = nullcontext() with ctx: From 0bde0bfa3e7791b4ff62118fae1d62ac869ab192 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:42:14 +0000 Subject: [PATCH 34/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 6 - tests/test_shardformer/test_model/_utils.py | 1 - .../test_model/test_shard_llama.py | 272 +++++++++--------- 3 files changed, 135 insertions(+), 144 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index c76656200a2f..201298cd65c2 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -446,14 +446,10 @@ def forward_step( # last layer in model if self.stage_manager.is_last_stage(): - print(f"aaaaa{output_obj=}") loss = criterion(output_obj, micro_batch) / self.num_microbatch - print(f"bbbb{loss=}") if accum_loss is not None: - print(f"accum_loss{accum_loss=}") if not torch.isinf(loss): accum_loss.add_(loss.detach()) - print(f"add accum_loss{accum_loss=}") if outputs is not None: outputs.append(tree_map(detach, output_obj)) return loss @@ -507,7 +503,6 @@ def backward_b_step( optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: output_obj_ = output_obj["hidden_states"] - print(f"{model_chunk_id=}, {output_obj_grad=}") optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad, @@ -775,7 +770,6 @@ def run_forward_only( # return loss & output if outputs is not None: outputs = merge_batch(outputs) - print(f"{accum_loss=}") return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index dbefe872cbc8..5c141e8f5cf1 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -397,7 +397,6 @@ def get_grad_tensors_for_check( pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - print(f"grad_to_check {shard_grad=}") grad_to_check[suffix] = { "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 81892817cec4..9f67ecbea687 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -136,7 +136,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, rtol=rtol, shard_config=booster.plugin.shard_config, ) - print(f"check_loss") check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): @@ -144,7 +143,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - print(f"check_weight") check_weight( llama_model, shard_llama_model, @@ -164,135 +162,135 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # # Double Ring Attention - # { - # "tp_size": 1, - # "pp_size": 1, - # "sp_size": 4, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # "inner_ring_size": 2, - # }, - # 
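# The pre-commit fix above (dropping the unused "as e") leaves the guard
# itself unchanged: ask the optimizer for a no_sync() context, which the
# LowLevelZero path provides, and fall back to nullcontext() so the B step
# never triggers a premature gradient sync. The pattern on its own, as a
# small sketch with an illustrative helper name:
from contextlib import nullcontext

def no_sync_or_null(optimizer):
    try:
        return optimizer.no_sync()  # e.g. a ZeRO optimizer wrapper
    except Exception:
        return nullcontext()        # plain optimizers expose no no_sync()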
# Ring Attention + PP - # { - # "tp_size": 1, - # "pp_size": 2, - # "sp_size": 2, - # "num_microbatches": 2, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # # Ring Attention + TP - # { - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { # Ulysess + TP - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { # Ulysess + PP - # "tp_size": 1, - # "pp_size": 2, - # "sp_size": 2, - # "num_microbatches": 2, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "split_gather", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 4, - # "use_lazy_init": False, - # "precision": "fp32", - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, + # Double Ring Attention + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 4, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + "inner_ring_size": 2, + }, + # Ring Attention + PP + { + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + # Ring Attention + TP + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + 
"enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + TP + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + PP + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -303,7 +301,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "zero_stage": 1, "initial_scale": 1, - "enable_gradient_checkpointing": False, + "enable_gradient_checkpointing": True, "parallel_output": False, }, ], @@ -412,11 +410,11 @@ def test_llama(): spawn(check_llama, 4) -# @pytest.mark.largedist -# @rerun_if_address_is_in_use() -# @clear_cache_before_run() -# def test_llama_3d(): -# spawn(check_llama_3d, 8) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) if __name__ == "__main__": From a2f187bcab1a2236a17a32fb8415eb93a817cee2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 02:32:55 +0000 Subject: [PATCH 35/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/pipeline/schedule/zero_bubble_pp.py | 16 ++++++++-------- .../test_schedule/test_zerobubble_pp.py | 8 +++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 201298cd65c2..20ea46b77adc 100644 --- 
a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -160,13 +160,13 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) # if self.forward_only: # self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 - # NOTE: disable metadata cache when batch size changes (not valid anymore) - # if self.batch_size != self.last_batch_size: - # self.enable_metadata_cache = False - # self.send_tensor_metadata = True - # self.send_grad_metadata = True - # self.tensor_metadata_recv = None - # self.grad_metadata_recv = None + # NOTE: disable metadata cache when batch size changes (not valid anymore) + # if self.batch_size != self.last_batch_size: + # self.enable_metadata_cache = False + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None self.last_batch_size = self.batch_size @@ -443,7 +443,7 @@ def forward_step( internal_inputs = {} if input_obj is None else input_obj internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - + # last layer in model if self.stage_manager.is_last_stage(): loss = criterion(output_obj, micro_batch) / self.num_microbatch diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 0bde70ec6c16..31531069386e 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -11,7 +11,7 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper -from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -96,14 +96,16 @@ def run_pp( sharded_model = torch.nn.ModuleList() for idx, sub_model in enumerate(pp_model.layers): - if idx == rank or (NUM_LAYER-idx-1) == rank: + if idx == rank or (NUM_LAYER - idx - 1) == rank: sub_model._forward = sub_model.forward sub_model.forward = MethodType( partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)), sub_model._forward, ) sharded_model.append(sub_model.cuda()) - assert len(sharded_model) == num_model_chunk, f"{len(sharded_model)}, {num_model_chunk}, num_model_chunk is not correct" + assert ( + len(sharded_model) == num_model_chunk + ), f"{len(sharded_model)}, {num_model_chunk}, num_model_chunk is not correct" # create optimizer torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) From a9ac43860b503c1e3714aa4501ba938d5ea34434 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:47:33 +0000 Subject: [PATCH 36/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 2 +- .../test_schedule/test_zerobubble_pp.py | 28 +++++++------------ 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 20ea46b77adc..982c95c6a727 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -866,6 +866,6 @@ def forward_backward_step( model_chunk, 
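# The hunk continuing below re-enables the post-step assert_buffer_empty()
# check. The invariant it guards: once a schedule pass completes, every queued
# activation and gradient must have been consumed, otherwise a send/recv was
# scheduled but never drained. A sketch of such a check, assuming per-chunk
# buffer lists like the scheduler's:
def buffers_drained(buffers):
    # e.g. {"input": [[], []], "output": [[], []]}, one inner list per chunk
    return all(len(chunk) == 0 for chunks in buffers.values() for chunk in chunks)

assert buffers_drained({"input": [[], []], "output": [[], []]})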
data_iter, criterion, optimizer, return_loss, return_outputs ) - # self.assert_buffer_empty() + self.assert_buffer_empty() return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 31531069386e..6881893f157a 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -156,24 +156,16 @@ def criterion(x, *args, **kwargs): assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) - # forward only - with torch.no_grad(): - torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output) - - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True - ) - if stage_manager.is_first_stage(ignore_chunk=True): - print(f"{torch_loss=}, {pp_ret['loss']}") - assert_close(torch_loss, pp_ret["loss"]) - - for layer in sharded_model: - if layer.weight.grad is None: - assert layer.weight.grad is None and layer.bias.grad is None - else: - assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + # forward one step + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True + ) + if stage_manager.is_first_stage(ignore_chunk=True): + print(f"{torch_loss=}, {pp_ret['loss']}") + assert_close(torch_loss, pp_ret["loss"]) @pytest.mark.dist From 8c3cdea7d8c59953a9b74ab128f13862fd9a0b6a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 02:52:22 +0000 Subject: [PATCH 37/47] fix --- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 6881893f157a..ce1e869c892c 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -132,7 +132,6 @@ def criterion(x, *args, **kwargs): # check gradients for i in range(num_model_chunk): - # idx = world_size * i + rank if i == 0: idx = rank else: @@ -148,7 +147,6 @@ def criterion(x, *args, **kwargs): # check updated param for i in range(num_model_chunk): - # idx = world_size * i + rank if i == 0: idx = rank else: @@ -160,9 +158,7 @@ def criterion(x, *args, **kwargs): torch_output = torch_model(input_list[0]) torch_loss = criterion(torch_output) - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True - ) + pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) if stage_manager.is_first_stage(ignore_chunk=True): print(f"{torch_loss=}, {pp_ret['loss']}") assert_close(torch_loss, pp_ret["loss"]) From f07b3051d2954968035129671680ef78f12d957e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 04:15:12 +0000 Subject: [PATCH 38/47] fix --- colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 6 +++--- .../test_pipeline_utils/test_t5_pipeline_utils.py | 1 + .../test_pipeline_utils/test_whisper_pipeline_utils.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git 
a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 36973b240896..6fbb72ba3e5b 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -278,7 +278,7 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: @@ -300,7 +300,7 @@ def __init__( if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -309,7 +309,7 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index e2f71ff89221..f79bdeb3a96f 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -15,6 +15,7 @@ def __init__(self): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index d39c5ea91dd4..722b8fd7cfae 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -15,6 +15,7 @@ def __init__(self): self.is_interleave = False self.num_layers_per_stage = None self.num_model_chunks = 1 + self.use_zbv = False @property def num_stages(self): From 29ec566ecee0bf683d90b258cda6c9cb870b6d90 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 10:22:16 +0000 Subject: [PATCH 39/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 59 ++++++++++--------- .../test_schedule/test_zerobubble_pp.py | 27 ++++++--- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 982c95c6a727..8ee41ae753c8 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -5,7 +5,7 @@ import torch import torch.cuda from torch.nn import Module, ModuleList -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_flatten, tree_map from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper @@ -109,22 +109,15 @@ def _free_buffers(self): def assert_buffer_empty(self): # assert buuffer is empty at end - assert len(self.input_tensors[0]) == 0 - assert len(self.input_tensors[1]) == 0 - assert len(self.output_tensors[0]) == 0 - assert len(self.output_tensors[1]) == 0 - assert len(self.output_tensors_dw[0]) == 0 - assert len(self.output_tensors_dw[1]) == 0 - assert len(self.output_tensors_grad_dw[0]) == 0 - 
assert len(self.output_tensors_grad_dw[1]) == 0 - assert len(self.send_forward_buffer[0]) == 0 - assert len(self.send_forward_buffer[1]) == 0 - assert len(self.recv_forward_buffer[0]) == 0 - assert len(self.recv_forward_buffer[1]) == 0 - assert len(self.send_backward_buffer[0]) == 0 - assert len(self.send_backward_buffer[1]) == 0 - assert len(self.recv_backward_buffer[0]) == 0 - assert len(self.recv_backward_buffer[1]) == 0 + for chunk_id in [0, 1]: + assert len(self.input_tensors[chunk_id]) == 0, f"{self.input_tensors=}" + assert len(self.output_tensors[chunk_id]) == 0 + assert len(self.output_tensors_dw[chunk_id]) == 0 + assert len(self.output_tensors_grad_dw[chunk_id]) == 0 + assert len(self.send_forward_buffer[chunk_id]) == 0 + assert len(self.recv_forward_buffer[chunk_id]) == 0 + assert len(self.send_backward_buffer[chunk_id]) == 0 + assert len(self.recv_backward_buffer[chunk_id]) == 0 assert len(self.local_send_forward_buffer) == 0 assert len(self.local_send_backward_buffer) == 0 @@ -158,8 +151,8 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.num_microbatch % self.stage_manager.num_stages == 0 ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" - # if self.forward_only: - # self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 + if self.forward_only: + self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) # if self.batch_size != self.last_batch_size: # self.enable_metadata_cache = False @@ -490,8 +483,6 @@ def backward_b_step( # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - input_obj_ = input_obj["hidden_states"] - # Attempt to disable gradient synchronization when using the LowLevelZeroPlugin. try: ctx = optimizer.no_sync() @@ -499,17 +490,27 @@ def backward_b_step( ctx = nullcontext() with ctx: + input_obj_, tree_spec = tree_flatten(input_obj) if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: - output_obj_ = output_obj["hidden_states"] + output_obj_, _ = tree_flatten(output_obj) + output_obj_grad_, _ = tree_flatten(output_obj_grad) optimizer.backward_by_grad( tensor=output_obj_, - grad=output_obj_grad, + grad=output_obj_grad_, inputs=input_obj_, retain_graph=True, ) - return input_obj_.grad + + # Collect the grad of the input_obj. 
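+        # Only tensor-valued entries of input_obj whose .grad was actually populated
+        # are collected below; non-tensor entries are skipped, so the returned dict
+        # mirrors input_obj and can be sent backward as-is.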
+ input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad def backward_w_step( self, @@ -539,13 +540,15 @@ def backward_w_step( if isinstance(model_chunk, ModuleList): model_chunk = model_chunk[model_chunk_id] + if output_obj_grad is None: optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: - output_obj_ = output_obj["hidden_states"] + output_obj_, _ = tree_flatten(output_obj) + output_obj_grad_, _ = tree_flatten(output_obj_grad) optimizer.backward_by_grad( tensor=output_obj_, - grad=output_obj_grad, + grad=output_obj_grad_, inputs=list(model_chunk.parameters()), retain_graph=False, ) @@ -861,11 +864,11 @@ def forward_backward_step( if self.forward_only: result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs) + self._free_buffers() else: result = self.run_forward_backward( model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) - - self.assert_buffer_empty() + self.assert_buffer_empty() return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index ce1e869c892c..8ad51610bfa8 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -136,7 +136,6 @@ def criterion(x, *args, **kwargs): idx = rank else: idx = world_size * 2 - rank - 1 - print(f"{i=}, {idx=}, {rank=}, {torch_model.layers[idx].weight.grad=}, {sharded_model[i].weight.grad=}") assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) @@ -154,14 +153,24 @@ def criterion(x, *args, **kwargs): assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) - # forward one step - torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output) - - pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) - if stage_manager.is_first_stage(ignore_chunk=True): - print(f"{torch_loss=}, {pp_ret['loss']}") - assert_close(torch_loss, pp_ret["loss"]) + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True + ) + if stage_manager.is_first_stage(ignore_chunk=True): + assert_close(torch_loss, pp_ret["loss"]) + + for layer in sharded_model: + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None + else: + assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + torch.cuda.empty_cache() @pytest.mark.dist From b57e78ded6831351c278feca8bd25016504c047b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Sep 2024 11:28:52 +0000 Subject: [PATCH 40/47] fix --- .../pipeline/schedule/zero_bubble_pp.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 8ee41ae753c8..7b9ca8dd726c 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py 
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -153,13 +153,13 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) if self.forward_only: self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 - # NOTE: disable metadata cache when batch size changes (not valid anymore) - # if self.batch_size != self.last_batch_size: - # self.enable_metadata_cache = False - # self.send_tensor_metadata = True - # self.send_grad_metadata = True - # self.tensor_metadata_recv = None - # self.grad_metadata_recv = None + # NOTE: disable metadata cache when batch size changes (not valid anymore) + # if self.batch_size != self.last_batch_size: + # self.enable_metadata_cache = False + # self.send_tensor_metadata = True + # self.send_grad_metadata = True + # self.tensor_metadata_recv = None + # self.grad_metadata_recv = None self.last_batch_size = self.batch_size @@ -490,12 +490,14 @@ def backward_b_step( ctx = nullcontext() with ctx: - input_obj_, tree_spec = tree_flatten(input_obj) + input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)}) if output_obj_grad is None: optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) else: - output_obj_, _ = tree_flatten(output_obj) - output_obj_grad_, _ = tree_flatten(output_obj_grad) + output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) + output_obj_grad_, _ = tree_flatten( + {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} + ) optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad_, @@ -544,8 +546,10 @@ def backward_w_step( if output_obj_grad is None: optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) else: - output_obj_, _ = tree_flatten(output_obj) - output_obj_grad_, _ = tree_flatten(output_obj_grad) + output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) + output_obj_grad_, _ = tree_flatten( + {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} + ) optimizer.backward_by_grad( tensor=output_obj_, grad=output_obj_grad_, From 083ea313b5ac676e2c3e4f93d471d7be68f66fc3 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 27 Sep 2024 03:39:04 +0000 Subject: [PATCH 41/47] fix --- colossalai/pipeline/schedule/_utils.py | 38 + .../pipeline/schedule/zero_bubble_pp.py | 275 ++--- .../test_schedule/test_zerobubble_pp.py | 957 +++++++++++++++--- .../test_model/test_shard_llama.py | 268 ++--- 4 files changed, 1138 insertions(+), 400 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 271b3238f5c4..b641eb3645cd 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None: x.retain_grad() +def require_grad(x: Any) -> None: + """Call require_grad on a tensor. + + Args: + x (Any): Object to be called. + """ + if isinstance(x, torch.Tensor) and not x.requires_grad: + x.requires_grad_() + + def detach(x: Any) -> Any: """Call detach() on a tensor. @@ -145,6 +155,34 @@ def detach(x: Any) -> Any: return x +def clone(x: Any) -> Any: + """Call clone() on a tensor. + + Args: + x (Any): Object to be called. + + Returns: + Any: The cloned object. + """ + if isinstance(x, torch.Tensor): + return x.clone() + return x + + +def release_tensor_data(x: Any) -> Any: + """Call untyped_storage().resize_(0) on a tensor. 
Used to release tensor.data and keep grad_fn.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The resized storage of x.data, or x unchanged if it is not a tensor.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.data.untyped_storage().resize_(0)
+    return x
+
+
 def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 7b9ca8dd726c..5c25c5bfaa80 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -1,4 +1,3 @@
-from contextlib import nullcontext
 from functools import partial
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -13,7 +12,18 @@
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    clone,
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    release_tensor_data,
+    require_grad,
+    retain_grad,
+    to_device,
+)
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -25,20 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
         req.wait()
 
 
-def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if (out is None) or (not deallocate_pipeline_outputs):
-        return
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    out.data.untyped_storage().resize_(0)
-
-
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
@@ -59,9 +55,10 @@ def __init__(
         self.batch_size: int
         self.last_batch_size: Optional[int] = None
         self.microbatch_offset: List[int]
+
+        self.schedules = schedule
         # TODO: optim post valid
         self.do_post_validation = False
-        self.schedules = schedule
 
         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -109,15 +106,22 @@ def _free_buffers(self):
     def assert_buffer_empty(self):
         # assert buuffer is empty at end
-        for chunk_id in [0, 1]:
-            assert len(self.input_tensors[chunk_id]) == 0, f"{self.input_tensors=}"
-            assert len(self.output_tensors[chunk_id]) == 0
-            assert len(self.output_tensors_dw[chunk_id]) == 0
-            assert len(self.output_tensors_grad_dw[chunk_id]) == 0
-            assert len(self.send_forward_buffer[chunk_id]) == 0
-            assert len(self.recv_forward_buffer[chunk_id]) == 0
-            assert len(self.send_backward_buffer[chunk_id]) == 0
-            assert len(self.recv_backward_buffer[chunk_id]) == 0
+        assert len(self.input_tensors[0]) == 0
+        assert len(self.input_tensors[1]) == 0
+        assert len(self.output_tensors[0]) == 0
+        assert len(self.output_tensors[1]) == 0
+        assert len(self.output_tensors_dw[0]) == 0
+        assert len(self.output_tensors_dw[1]) == 0
+        assert len(self.output_tensors_grad_dw[0]) == 0
+        assert len(self.output_tensors_grad_dw[1]) == 0
+        assert len(self.send_forward_buffer[0]) == 0
+        assert len(self.send_forward_buffer[1]) == 0
+        assert len(self.recv_forward_buffer[0]) == 0
+        assert len(self.recv_forward_buffer[1]) == 0
+        assert len(self.send_backward_buffer[0]) == 0
+        assert len(self.send_backward_buffer[1]) == 0
+        assert len(self.recv_backward_buffer[0]) == 0
+        assert len(self.recv_backward_buffer[1]) == 0
 
         assert len(self.local_send_forward_buffer) == 0
         assert len(self.local_send_backward_buffer) == 0
@@ -326,7 +330,6 @@ def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
                 next_rank = self.stage_manager.get_next_rank()
             output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
             send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
-            tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_tensor)
             return send_handles
 
         else:
@@ -345,7 +348,6 @@ def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
                 prev_rank = self.stage_manager.get_prev_rank()
             output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
             send_handles = self.comm.send_forward(output_tensor, prev_rank)
-            tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_tensor)
             return send_handles
 
     def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
@@ -403,6 +405,7 @@ def forward_step(
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         criterion: Callable,
         accum_loss: Optional[torch.Tensor] = None,
@@ -421,28 +424,20 @@ def forward_step(
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
         """
         # Load input ids, attention mask and labels
-        # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
-
-        # for the first stage, input_obj is None
+        # for the first stage, input_obj is None; so we use micro_batch as the input_obj
        # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
        # Only attention_mask from micro_batch is used
-        micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
-
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            if isinstance(model_chunk, ModuleList):
-                output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
-            else:
-                # fwd calculate
-                internal_inputs = {} if input_obj is None else input_obj
-                internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-                output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
+            # fwd calculate
+            internal_inputs = {} if input_obj is None else input_obj
+            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+            output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
 
         # last layer in model
-        if self.stage_manager.is_last_stage():
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
-                if not torch.isinf(loss):
-                    accum_loss.add_(loss.detach())
+                accum_loss.add_(loss.detach())
             if outputs is not None:
                 outputs.append(tree_map(detach, output_obj))
             return loss
@@ -454,6 +449,7 @@ def backward_b_step(
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
+        # micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -464,7 +460,7 @@ def backward_b_step(
             model_chunk (ModuleList or Module): Model Chunk to be run;
             model_chunk_id (int): The current model chunk idx;
             optimizer (OptimizerWrapper): Optimizer to update the model
-            input_obj (Optional[dict]): x.
+            input_obj (Optional[Tuple[dict]]): x. (microbatch, input_obj)
             output_obj (Union[dict, torch.Tensor]): y.
             output_obj_grad (dict): dy.
 
@@ -473,42 +469,49 @@ def backward_b_step(
         """
         # calculate bwd b step ; only dx = w*dy;
 
-        # Retain the grad on the input_obj.
-        if input_obj is None:
-            return None
-        else:
+        # Retain the grad on the input_obj. No need to retain_grad on the micro_batch.
+        if input_obj is not None:
             tree_map(retain_grad, input_obj)
 
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # loss backward; output_obj is loss; so output_obj_grad should be None
-            assert output_obj_grad is None
+        # x, y, dy lists for backward_by_grad; type: list[tensor]
+        input_obj_ = []
+        output_obj_ = []
+        output_obj_grad_ = []
 
-        # Attempt to disable gradient synchronization when using the LowLevelZeroPlugin.
-        try:
-            ctx = optimizer.no_sync()
-        except Exception:
-            ctx = nullcontext()
+        # For chunk 0 stage 0, use micro_batch as input_obj_; we don't have to calculate the micro_batch dx.
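+        # (the micro_batch itself is leaf input data: there is no upstream stage to
+        #  send a dx to, so the B step degenerates to a no-op here)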
+ if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): + return None - with ctx: - input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)}) - if output_obj_grad is None: - optimizer.backward(output_obj, inputs=input_obj_, retain_graph=True) - else: - output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) - output_obj_grad_, _ = tree_flatten( - {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} - ) - optimizer.backward_by_grad( - tensor=output_obj_, - grad=output_obj_grad_, - inputs=input_obj_, - retain_graph=True, - ) + # For loss backward; output_obj is loss; output_obj_grad should be None + elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + assert output_obj_grad is None + input_obj_, _ = tree_flatten(input_obj) + output_obj_.append(output_obj) # LOSS + output_obj_grad_.append(output_obj_grad) # None - # Collect the grad of the input_obj. - input_obj_grad = None - if input_obj is not None: - input_obj_grad = {} + # For other chunk stage, use input_obj as input_obj_; + else: + input_obj_, _ = tree_flatten(input_obj) + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + + # filter item which is not torch.Tensor + input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] + + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + inputs=input_obj_, + retain_graph=True, + ) + + # Format output_obj_grad + input_obj_grad = {} + if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): + pass + else: for k, v in input_obj.items(): if isinstance(v, torch.Tensor) and v.grad is not None: input_obj_grad[k] = v.grad @@ -536,26 +539,28 @@ def backward_w_step( """ # calculate bwd w step ; only dw = x*dy; - if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # loss backward; output_obj is loss - output_obj_grad = None + # y, dy list for w backward_by_grad; Type: list[tensor]; + output_obj_ = [] + output_obj_grad_ = [] - if isinstance(model_chunk, ModuleList): - model_chunk = model_chunk[model_chunk_id] - - if output_obj_grad is None: - optimizer.backward(output_obj, inputs=list(model_chunk.parameters()), retain_graph=False) + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # loss backward; output_obj is loss; + output_obj_.append(output_obj) # LOSS + output_obj_grad_.append(None) # None else: - output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) - output_obj_grad_, _ = tree_flatten( - {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)} - ) - optimizer.backward_by_grad( - tensor=output_obj_, - grad=output_obj_grad_, - inputs=list(model_chunk.parameters()), - retain_graph=False, - ) + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + + # filter item which is not torch.Tensor + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] + + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + 
inputs=list(model_chunk.parameters()), + retain_graph=False, + ) def schedule_f( self, @@ -579,9 +584,10 @@ def schedule_f( Returns: Nothing. """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # Step1: recv fwd if model_chunk_id == 0: - # is first stage; get input from func param + # is first stage; get input from microbatch if self.stage_manager.is_first_stage(ignore_chunk=True): input_obj = None else: @@ -595,44 +601,68 @@ def schedule_f( input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) # Here, let input_obj.requires_grad_() - # add input and output object for backward b - if input_obj is not None: - tree_map(torch.Tensor.requires_grad_, input_obj) + # if input_obj is not None: + if not isinstance(input_obj, torch.Tensor): + tree_map(require_grad, input_obj) + + # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd, + # tree_map(torch.Tensor.requires_grad_, micro_batch) # Step2: fwd step output_obj = self.forward_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, + micro_batch=micro_batch, input_obj=input_obj, criterion=criterion, accum_loss=accum_loss, outputs=outputs, ) - # Step3: send fwd + # Step3: + # 3-1:detach output; detach output for send fwd; + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # We should not detach bwd LOSS + pass + else: + # detach output + detached_output_obj = tree_map(detach, output_obj) + # 3-2 clone detached_output_obj + detached_output_obj = tree_map(clone, detached_output_obj) + + # 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output) + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # We should not release_tensor_data bwd LOSS + pass + else: + # release_tensor_data output + tree_map(release_tensor_data, output_obj) + + # add input and output object for backward b + self.input_tensors[model_chunk_id].append(input_obj) + + # for bwd b&w, we only need the graph(grad_fn) of output_obj + # Do not release_tensor_data loss, release_tensor_data other output_obj; + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + self.output_tensors[model_chunk_id].append(output_obj) + self.output_tensors_dw[model_chunk_id].append(output_obj) + else: + self.output_tensors[model_chunk_id].append(output_obj) + self.output_tensors_dw[model_chunk_id].append(output_obj) + # add output to send_fwd_buffer - if model_chunk_id == 0: + if model_chunk_id == 0: # chunk 0 # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): - output_obj_clone = tree_map(torch.Tensor.clone, output_obj) - detached_output_obj = tree_map(detach, output_obj_clone) self.local_send_forward_buffer.append(detached_output_obj) - # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), output_obj) else: - self.send_forward_buffer[model_chunk_id].append(output_obj) - else: - # is first stage; end of fwd; append LOSS to local_send_backward_buffer + self.send_forward_buffer[model_chunk_id].append(detached_output_obj) + else: # chunk 1 + # is first stage; end of fwd; do nothing if self.stage_manager.is_first_stage(ignore_chunk=True): pass else: - self.send_forward_buffer[model_chunk_id].append(output_obj) - - # add input and output object for backward b - self.input_tensors[model_chunk_id].append(input_obj) - self.output_tensors[model_chunk_id].append(output_obj) - # add 
output object for backward w - self.output_tensors_dw[model_chunk_id].append(output_obj) + self.send_forward_buffer[model_chunk_id].append(detached_output_obj) def schedule_b( self, @@ -650,20 +680,19 @@ def schedule_b( Returns: Nothing. """ - # Step1: recv bwd if model_chunk_id == 0: # chunk0 is last stage; recv output_grad from local_send_backward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): output_tensor_grad = self.local_send_backward_buffer.pop(0) - # chunk 0 not last stage; recv output_grad from recv_backward_buffer + # chunk0 not last stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): output_tensor_grad = None - # chunk1, not first stage; recv output_grad from recv_backward_buffer + # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) @@ -757,10 +786,9 @@ def run_forward_only( outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None # while we still have schedules_node in self.schedules - # while we still have schedules_node in self.schedules - schedules = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) - for it in range(len(schedules)): - scheduled_node = schedules[it] + for it in range(len(self.schedules)): + scheduled_node = self.schedules[it] + if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: # communication communication_func = self.communication_map[scheduled_node.type] @@ -811,8 +839,7 @@ def run_forward_backward( # communication communication_func = self.communication_map[scheduled_node.type] communication_func(scheduled_node.chunk) - - if scheduled_node.type == "F": + elif scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, @@ -868,11 +895,11 @@ def forward_backward_step( if self.forward_only: result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs) - self._free_buffers() else: result = self.run_forward_backward( model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) - self.assert_buffer_empty() + + self.assert_buffer_empty() return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 8ad51610bfa8..0f2d6c49c749 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -1,6 +1,6 @@ -import copy +from copy import deepcopy from functools import partial -from types import MethodType +from typing import Tuple import pytest import torch @@ -11,183 +11,856 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper -from colossalai.pipeline.schedule.v_schedule import PipelineGraph +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all - -NUM_LAYER = 8 -DIM = 4 +from colossalai.tensor.d_tensor.api import clear_layout_converter +from 
colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo class MlpModel(nn.Module): - def __init__(self): + def __init__( + self, + in_dim, + out_dim, + num_layers, + stage_index=None, + stage_mgr: PipelineStageManager = None, + ): super().__init__() - self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)]) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -def pp_linear_fwd( - forward, - data: torch.Tensor = None, - hidden_states: torch.Tensor = None, - stage_mgr: PipelineStageManager = None, - model_chunk_id: int = None, -): - with stage_mgr.switch_model_chunk_id(model_chunk_id): - if stage_mgr.is_first_stage(): - return {"hidden_states": forward(data)} - elif stage_mgr.is_last_stage(): - return forward(hidden_states) + self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) + + def forward( + self, + data: torch.Tensor = None, + hidden_states: torch.Tensor = None, + stage_index=None, + stage_mgr: PipelineStageManager = None, + model_chunk_id: int = None, + ): + if stage_mgr is None: + hidden_states = data + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states else: - return {"hidden_states": forward(hidden_states)} - - -def run_pp( - rank: int, - world_size: int, - port: int, - num_microbatch: int, - batch_size: int, - num_model_chunk: int, -): - """ - This test is to examine the correctness of interleaved 1F1B, compared with torch. - Be aware it contains some hardcodes. - """ - colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") - - # create model - seed_all(1453) - torch_model = MlpModel().cuda() - pp_model = copy.deepcopy(torch_model).cuda() - - pg_mesh = ProcessGroupMesh(world_size) + # Set not used layer to None + held_layers = self.layers[stage_index[0] : stage_index[1]] + + # fwd end + if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1: + return held_layers(hidden_states) + # fwd start + elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0: + return {"hidden_states": held_layers(data)} + # fwd middle + else: + return {"hidden_states": held_layers(hidden_states)} + + +def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups): + for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()): + if key_base == key_pp: + if key_base != "params": + assert val_base == val_pp + + +def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +# 1) Test manual v_schedule with multiple microbatch +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_iter_input(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager stage_manager = PipelineStageManager( - pg_mesh, pipeline_axis=0, enable_interleave=True, use_zbv=True, num_model_chunks=num_model_chunk + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk 
    )

    # schedule list: every stage repeats the same per-microbatch pattern
    # (chunk 0 fwd, chunk 1 fwd, chunk 1 bwd, chunk 0 bwd, with each bwd split
    # into B (dx) and W (dw)), so build the nodes with loops instead of
    # spelling out all 4 stages x 4 microbatches by hand.
    zbv_schedule = []
    for stage in range(pp_size):
        stage_schedule = []
        for minibatch in range(num_microbatch):
            # fwd: chunk 0 first, then chunk 1 (down and then up the "V")
            for chunk in (0, 1):
                stage_schedule.extend(
                    [
                        ScheduledNode(type="RECV_FORWARD", chunk=chunk, stage=stage, minibatch=minibatch),
                        ScheduledNode(type="F", chunk=chunk, stage=stage, minibatch=minibatch),
                        ScheduledNode(type="SEND_FORWARD", chunk=chunk, stage=stage, minibatch=minibatch),
                    ]
                )
            # bwd: chunk 1 first, then chunk 0; B computes dx, W computes dw
            for chunk in (1, 0):
                stage_schedule.extend(
                    [
                        ScheduledNode(type="RECV_BACKWARD", chunk=chunk, stage=stage, minibatch=minibatch),
                        ScheduledNode(type="B", chunk=chunk, stage=stage, minibatch=minibatch),
                        ScheduledNode(type="W", chunk=chunk, stage=stage, minibatch=minibatch),
                        ScheduledNode(type="SEND_BACKWARD", chunk=chunk, stage=stage, minibatch=minibatch),
                    ]
                )
        zbv_schedule.append(stage_schedule)

    scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule,  # hint: send whole schedule or local schedule only?
+ stage_manager=stage_manager, + num_model_chunks=pp_size, + num_microbatch=num_microbatch, + overlap_p2p=False, + ) + + # loss func + def criterion(x, *args, **kwargs): + return (x * x).mean() + + # init model and input + batch_size = 4 + num_layers = 8 + in_dim = out_dim = 8 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] + + input_base = [t.clone() for t in data_iter] + model_base = deepcopy(model) + + if rank == 0: + # layer 0 & 7 to chunk 0 on rank0 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 0 or idx == 7: + local_chunk.append(sub_model) + elif rank == 1: + # layer 1 & 6 to chunk 1 on rank1 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 1 or idx == 6: + local_chunk.append(sub_model) + elif rank == 2: + # layer 2 & 5 to chunk 2 on rank2 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 2 or idx == 5: + local_chunk.append(sub_model) + else: + # layer 3 & 4 to chunk 3 on rank3 + local_chunk = torch.nn.ModuleList().to(rank) + for idx, sub_model in enumerate(model.layers): + if idx == 3 or idx == 4: + local_chunk.append(sub_model) + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5)) + + print( + f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};" + ) + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=local_chunk, + data_iter=iter(data_iter), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base[0]) + loss_base = criterion(output_base) + loss_base.backward() + optimizer_base.step() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ########################## + # assert weight + ########################## + if rank == 0: + # layer 0 + assert_close(local_chunk[0].weight, model_base.layers[0].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(local_chunk[1].weight, model_base.layers[7].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # layer 1 + assert_close(local_chunk[0].weight, model_base.layers[1].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(local_chunk[1].weight, model_base.layers[6].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad) + if rank == 2: + # layer 2 + assert_close(local_chunk[0].weight, model_base.layers[2].weight) + assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(local_chunk[1].weight, model_base.layers[5].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad) + if rank == 3: + # layer 3 + assert_close(local_chunk[0].weight, model_base.layers[3].weight) + assert_close(local_chunk[0].weight.grad, 
model_base.layers[3].weight.grad) + # layer 4 + assert_close(local_chunk[1].weight, model_base.layers[4].weight) + assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) + + +# 2) add optimizer base 1) +@parameterize( + "test_config", + [ + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + { + "batch_size": 8, + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 8, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunk": 2, + }, + ], +) +def run_fwd_bwd_vschedule_with_optim(test_config): + # init dist + rank = dist.get_rank() + pp_size = test_config["pp_size"] + pg_mesh = ProcessGroupMesh(pp_size) + num_microbatch = test_config["num_microbatches"] + num_model_chunk = test_config["num_model_chunk"] + # stage_manager + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True + ) + + h, a, s = 4096, 32, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h mem_b = -mem_w - mem_f - scheduler_nodes = PipelineGraph( - n_stage=4, - n_micro=12, - f_cost=1000, - b_cost=1000, - w_cost=1000, + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatch, + f_cost=1, + b_cost=1, + w_cost=1, c_cost=1, f_mem=mem_f, b_mem=mem_b, w_mem=mem_w, - ).get_v_schedule() - schedule = ZeroBubbleVPipeScheduler( + # max_mem=mem_f * (p * 2 + m_offset), + ) + + zbv_schedule = graph.get_v_schedule() + + scheduler = ZeroBubbleVPipeScheduler( + schedule=zbv_schedule, # hint: send whole schedule or local schedule only ? stage_manager=stage_manager, - schedule=scheduler_nodes, num_model_chunks=num_model_chunk, num_microbatch=num_microbatch, + overlap_p2p=False, ) - sharded_model = torch.nn.ModuleList() - for idx, sub_model in enumerate(pp_model.layers): - if idx == rank or (NUM_LAYER - idx - 1) == rank: - sub_model._forward = sub_model.forward - sub_model.forward = MethodType( - partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)), - sub_model._forward, - ) - sharded_model.append(sub_model.cuda()) - assert ( - len(sharded_model) == num_model_chunk - ), f"{len(sharded_model)}, {num_model_chunk}, num_model_chunk is not correct" - - # create optimizer - torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) - pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5)) - - # create data - seed_all(115) - input_list = [torch.rand(batch_size, DIM).cuda()] - dist.all_reduce(input_list[0]) - + # init loss func def criterion(x, *args, **kwargs): + x = x["hidden_states"] return (x * x).mean() - # forward and backward - torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output) - torch_loss.backward() + def criterion_base(x, *args, **kwargs): + return (x * x).mean() - pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) + # init model and input + batch_size = test_config["batch_size"] + num_layers = 8 + assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" + in_dim = out_dim = 1024 + before_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") + model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) + data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)} + 
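# NOTE: micro-batches are passed between stages as dicts in this test, which is
+ # why the input is wrapped as {"data": ...} and why criterion above unwraps
+ # x["hidden_states"] from the stage output while criterion_base takes a plain tensor
+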
input_base = {k: v.clone() for k, v in data_iter.items()} + model_base = deepcopy(model) + model_pp = deepcopy(model) + layers_per_stage = stage_manager.distribute_layers(len(model.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) - # check loss - if stage_manager.is_first_stage(ignore_chunk=True): - assert_close(torch_loss, pp_ret["loss"]) + model_pp._forward = model_pp.forward - # check gradients - for i in range(num_model_chunk): - if i == 0: - idx = rank - else: - idx = world_size * 2 - rank - 1 - assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) - assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) - - # step - torch_optimizer.step() - pp_optimizer.step() - pp_optimizer.zero_grad() - - # check updated param - for i in range(num_model_chunk): - if i == 0: - idx = rank - else: - idx = world_size * 2 - rank - 1 - assert_close(torch_model.layers[idx].weight, sharded_model[i].weight) - assert_close(torch_model.layers[idx].bias, sharded_model[i].bias) + model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager) + + # init optimizer + optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) + optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5)) + + after_init_memory = torch.cuda.memory_allocated() / 1024**3 + print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") + + torch.cuda.synchronize() + result = scheduler.forward_backward_step( + model_chunk=model_pp, + data_iter=iter([data_iter]), + criterion=criterion, + optimizer=optimizer_pp, + return_loss=True, + return_outputs=True, + ) + + optimizer_pp.step() - # forward only - with torch.no_grad(): - torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output) + after_pp_step_memory = torch.cuda.memory_allocated() / 1024**3 - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True + # assert memory + if rank != 0: + # w.grad: hid_dim * hid_dim * microbatch * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + # output: hid_dim * hid_dim * microbatch * 4(fp32) / 1024**3 + # optim: state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3 + print( + f" num_microbatch {num_microbatch} rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 * batch_size / 1024**3)}" + ) + assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 * batch_size / 1024**3) + else: + # rank0 will also hold output; + print( + f" num_microbatch {num_microbatch} rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}" + ) + assert round((after_pp_step_memory - after_init_memory), 5) <= round( + (in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 ) - if stage_manager.is_first_stage(ignore_chunk=True): - assert_close(torch_loss, pp_ret["loss"]) - for layer in sharded_model: - if layer.weight.grad is None: - assert layer.weight.grad is None and layer.bias.grad is None - else: - assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad)) - torch.cuda.empty_cache() + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + # output_base = 
model_base(input_base["data"]) + output_base = model_base.forward(data=input_base["data"]) + loss_base = criterion_base(output_base) + loss_base.backward() + optimizer_base.step() + + ########################## + # assert loss & output + ########################## + # only chunk 1 stage 0 hold loss and output + if rank == 0: + assert_close(result["loss"], loss_base) + assert_close(result["outputs"]["hidden_states"], output_base) + + # ########################## + # # assert weight & optim state + # ########################## + optim_base_state = optimizer_base.state_dict()["state"] + optim_pp_state = optimizer_pp.state_dict()["state"] + optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0] + optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0] + + if rank == 0: + # layer 0 + assert_close(model_pp.layers[0].weight, model_base.layers[0].weight) + assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad) + assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"]) + # layer 7 + assert_close(model_pp.layers[7].weight, model_base.layers[7].weight) + assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad) + assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"]) + if rank == 1: + # layer 1 + assert_close(model_pp.layers[1].weight, model_base.layers[1].weight) + assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad) + assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"]) + # layer 6 + assert_close(model_pp.layers[6].weight, model_base.layers[6].weight) + assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad) + assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"]) + if rank == 2: + # layer 2 + assert_close(model_pp.layers[2].weight, model_base.layers[2].weight) + assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad) + assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"]) + # layer 5 + assert_close(model_pp.layers[5].weight, model_base.layers[5].weight) + assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad) + assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"]) + if rank == 3: + # layer 3 + assert_close(model_pp.layers[3].weight, model_base.layers[3].weight) + assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad) + assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"]) + # layer 4 + assert_close(model_pp.layers[4].weight, model_base.layers[4].weight) + assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad) + assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"]) + + # assert optim param_groups + assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) + + +# TODO:4) support Hybrid base 3) +def run_with_hybridplugin(test_config): + pass + + +# TODO:5) support MoEHybrid base 3) +@parameterize( + "test_config", + [ + { + "pp_style": "zbv", + "tp_size": 1, + "ep_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunks": 2, + }, + ], +) +def run_with_moehybridplugin(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + # test_config["use_lazy_init"] = False + test_config["initial_scale"] = 2**16 + 
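# NOTE: initial_scale seeds the mixed-precision loss scaler for the (still
+ # commented-out) MoeHybridParallelPlugin path below; for now only the plain
+ # single-process bert baseline in this loop actually runs
+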
model_list = [ + "transformers_bert", + ] + clear_layout_converter() + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name in model_list: + # base param + model = model_fn() + data = data_gen_fn() + print(f"data {data}") + criterion = loss_fn + optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5) + + output = model(**data) + loss = criterion(output) + loss.backward() + optimizer.step() + print(f"output {output}") + + # # pp param + # model_pp = deepcopy(model) + # data_pp = deepcopy(data) + # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5)) + + # # init pipeline graph + # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024 + # mem_f = 34 * h + 5 * a * s + # mem_w = -32 * h + # mem_b = -mem_w - mem_f + # graph = PipelineGraph( + # n_stage=test_config["pp_size"], + # n_micro=test_config["num_microbatches"], + # f_cost=1, + # b_cost=1, + # w_cost=1, + # c_cost=1, + # f_mem=mem_f, + # b_mem=mem_b, + # w_mem=mem_w, + # # max_mem=mem_f * (p * 2 + m_offset), + # ) + + # zbv_schedule = graph.get_v_schedule() + + # test_config["scheduler_nodes"] = zbv_schedule + # plugin = MoeHybridParallelPlugin( + # **test_config + # ) + # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure( + # model = model_pp, + # optimizer = optimizer_pp, + # criterion = criterion, + # dataloader = data_pp, + # ) + + # output_pp = plugin.execute_pipeline( + # data_iter=iter(data), + # model=model, + # criterion=criterion, + # optimizer=optimizer, + # return_loss = True, + # return_outputs = True, + # ) + + +# TODO:6) support booster & Hybrid base 4) + + +# TODO:7) support booster & MoEHybrid base 4) +@parameterize( + "test_config", + [ + { + "pp_style": "zbv", + "tp_size": 1, + "ep_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunks": 2, + }, + ], +) +def run_with_booster_moehybridplugin(test_config): + pass + + +def run_dist(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + # run_fwd_bwd_iter_input() + run_fwd_bwd_vschedule_with_optim() + # run_with_moehybridplugin() + # run_with_booster_moehybridplugin() @pytest.mark.dist -@pytest.mark.parametrize("num_microbatch", [12]) -@pytest.mark.parametrize("batch_size", [24]) -@pytest.mark.parametrize("num_model_chunk", [2]) @rerun_if_address_is_in_use() -def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): - assert NUM_LAYER % num_model_chunk == 0 +def test_pp(): spawn( - run_pp, - nprocs=NUM_LAYER // num_model_chunk, - num_microbatch=num_microbatch, - batch_size=batch_size, - num_model_chunk=num_model_chunk, + run_dist, + nprocs=4, ) if __name__ == "__main__": - test_pp(num_microbatch=4, batch_size=12, num_model_chunk=2) + test_pp() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 9f67ecbea687..3f279a116cd7 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -162,135 +162,135 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # Double Ring Attention - { - "tp_size": 1, - "pp_size": 1, - "sp_size": 4, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": 
True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - "inner_ring_size": 2, - }, - # Ring Attention + PP - { - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, - # Ring Attention + TP - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring_attn", - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { # Ulysess + TP - "tp_size": 2, - "pp_size": 1, - "sp_size": 2, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { # Ulysess + PP - "tp_size": 1, - "pp_size": 2, - "sp_size": 2, - "num_microbatches": 2, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "all_to_all", - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 0, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 1, - "sp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "ring", - "enable_flash_attention": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 4, - "use_lazy_init": False, - "precision": "fp32", - "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - }, - { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 2, - "precision": "fp16", - "initial_scale": 1, - }, - { - "tp_size": 1, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": True, - "use_lazy_init": True, - "zero_stage": 1, - "precision": "fp16", - "initial_scale": 1, - }, + # # Double Ring Attention + # { + # "tp_size": 1, + # "pp_size": 1, + # "sp_size": 4, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # "inner_ring_size": 2, + # }, + # # Ring Attention + PP + # { + # "tp_size": 1, + # "pp_size": 2, + # "sp_size": 2, + # "num_microbatches": 2, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring_attn", + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # # Ring Attention + TP + # { + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": 
"ring_attn", + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { # Ulysess + TP + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 2, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { # Ulysess + PP + # "tp_size": 1, + # "pp_size": 2, + # "sp_size": 2, + # "num_microbatches": 2, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "all_to_all", + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 0, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 4, + # "pp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "split_gather", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "sp_size": 1, + # "num_microbatches": 1, + # "enable_sequence_parallelism": True, + # "sequence_parallelism_mode": "ring", + # "enable_flash_attention": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 2, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "precision": "fp16", + # "initial_scale": 1, + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 4, + # "use_lazy_init": False, + # "precision": "fp32", + # "enable_gradient_checkpointing": True, + # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + # }, + # { + # "tp_size": 2, + # "pp_size": 1, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 2, + # "precision": "fp16", + # "initial_scale": 1, + # }, + # { + # "tp_size": 1, + # "pp_size": 2, + # "num_microbatches": 2, + # "enable_all_optimization": True, + # "use_lazy_init": True, + # "zero_stage": 1, + # "precision": "fp16", + # "initial_scale": 1, + # }, { "tp_size": 2, "pp_size": 2, @@ -410,11 +410,11 @@ def test_llama(): spawn(check_llama, 4) -@pytest.mark.largedist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama_3d(): - spawn(check_llama_3d, 8) +# @pytest.mark.largedist +# @rerun_if_address_is_in_use() +# @clear_cache_before_run() +# def test_llama_3d(): +# spawn(check_llama_3d, 8) if __name__ == "__main__": From 83d07668babb9d916751d264087f1ac7e8565a6c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 27 Sep 2024 03:40:06 +0000 Subject: [PATCH 42/47] fix --- .../test_model/test_shard_llama.py | 268 +++++++++--------- 1 file changed, 134 insertions(+), 134 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3f279a116cd7..9f67ecbea687 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -162,135 +162,135 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # # Double Ring Attention - # { - # "tp_size": 1, - # "pp_size": 1, - # "sp_size": 4, 
- # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # "inner_ring_size": 2, - # }, - # # Ring Attention + PP - # { - # "tp_size": 1, - # "pp_size": 2, - # "sp_size": 2, - # "num_microbatches": 2, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # # Ring Attention + TP - # { - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring_attn", - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { # Ulysess + TP - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 2, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { # Ulysess + PP - # "tp_size": 1, - # "pp_size": 2, - # "sp_size": 2, - # "num_microbatches": 2, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "all_to_all", - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 0, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 4, - # "pp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "split_gather", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "sp_size": 1, - # "num_microbatches": 1, - # "enable_sequence_parallelism": True, - # "sequence_parallelism_mode": "ring", - # "enable_flash_attention": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 2, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "precision": "fp16", - # "initial_scale": 1, - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 4, - # "use_lazy_init": False, - # "precision": "fp32", - # "enable_gradient_checkpointing": True, - # "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), - # }, - # { - # "tp_size": 2, - # "pp_size": 1, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 2, - # "precision": "fp16", - # "initial_scale": 1, - # }, - # { - # "tp_size": 1, - # "pp_size": 2, - # "num_microbatches": 2, - # "enable_all_optimization": True, - # "use_lazy_init": True, - # "zero_stage": 1, - # "precision": "fp16", - # "initial_scale": 1, - # }, + # Double Ring Attention + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 4, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + "inner_ring_size": 2, + }, + # Ring Attention + PP + { + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + 
"sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + # Ring Attention + TP + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + TP + "tp_size": 2, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { # Ulysess + PP + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 0, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "sp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -410,11 +410,11 @@ def test_llama(): spawn(check_llama, 4) -# @pytest.mark.largedist -# @rerun_if_address_is_in_use() -# @clear_cache_before_run() -# def test_llama_3d(): -# spawn(check_llama_3d, 8) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) if __name__ == "__main__": From 9b3c266b5b199e6d6906fd0af4fb7dc701de93c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Sep 2024 03:45:03 +0000 Subject: [PATCH 43/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/pipeline/stage_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 5e6177ce8196..5cc32114daff 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -96,7 +96,7 @@ def get_stage_index( ] ) return 
stage_indices - + for model_chunk in range(num_model_chunks): start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] From f28893019d7e2465752af9b5379b47bfd289b405 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 27 Sep 2024 03:52:59 +0000 Subject: [PATCH 44/47] fix --- tests/test_shardformer/test_model/test_shard_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 9f67ecbea687..ce886d28baf3 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -367,7 +367,7 @@ def run_llama_test(test_config): "num_microbatches": 4, "enable_all_optimization": False, "precision": "fp16", - "zero_stage": 1, + "zero_stage": 0, "initial_scale": 1, "enable_gradient_checkpointing": True, "gradient_checkpoint_config": PipelineGradientCheckpointConfig( From 93557f5eebbbf8a80671f418561960b9486e3546 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 27 Sep 2024 03:55:21 +0000 Subject: [PATCH 45/47] fix --- tests/test_shardformer/test_model/test_shard_llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ce886d28baf3..d925687cd875 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -299,7 +299,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 4, "enable_all_optimization": False, "precision": "fp16", - "zero_stage": 1, + "zero_stage": 0, "initial_scale": 1, "enable_gradient_checkpointing": True, "parallel_output": False, @@ -367,7 +367,7 @@ def run_llama_test(test_config): "num_microbatches": 4, "enable_all_optimization": False, "precision": "fp16", - "zero_stage": 0, + "zero_stage": 1, "initial_scale": 1, "enable_gradient_checkpointing": True, "gradient_checkpoint_config": PipelineGradientCheckpointConfig( From 90a82e2309d2872d10b35ae2bc3dc08431f486da Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 27 Sep 2024 04:14:04 +0000 Subject: [PATCH 46/47] fix --- .../amp/naive_amp/mixed_precision_optimizer.py | 7 ++++++- colossalai/pipeline/schedule/zero_bubble_pp.py | 17 ++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 121c92011cc5..8fb56aee4fce 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -91,7 +91,12 @@ def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) - tensor.backward(grad, inputs=inputs, retain_graph=retain_graph) + torch.autograd.backward( + tensors=tensor, + grad_tensors=grad, + inputs=inputs, + retain_graph=retain_graph, + ) def zero_grad(self, *args, **kwargs): for p in self.working_to_master_map.keys(): diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 
5c25c5bfaa80..8fc1c08c1de1 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -499,13 +499,16 @@ def backward_b_step( input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None] output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] - - optimizer.backward_by_grad( - tensor=output_obj_, - grad=output_obj_grad_, - inputs=input_obj_, - retain_graph=True, - ) + try: + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + inputs=input_obj_, + retain_graph=True, + ) + except Exception as e: + print(f"{output_obj_=}") + raise e # Format output_obj_grad input_obj_grad = {} From 74035808f0b70e52e2cacf8e048d68b53d362335 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 27 Sep 2024 04:17:49 +0000 Subject: [PATCH 47/47] fix --- colossalai/pipeline/schedule/zero_bubble_pp.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 8fc1c08c1de1..5c25c5bfaa80 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -499,16 +499,13 @@ def backward_b_step( input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None] output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] - try: - optimizer.backward_by_grad( - tensor=output_obj_, - grad=output_obj_grad_, - inputs=input_obj_, - retain_graph=True, - ) - except Exception as e: - print(f"{output_obj_=}") - raise e + + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + inputs=input_obj_, + retain_graph=True, + ) # Format output_obj_grad input_obj_grad = {}
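
Note on the last two patches: both versions funnel the B step of `backward_b_step` through `optimizer.backward_by_grad(..., inputs=input_obj_, retain_graph=True)`, which `MixedPrecisionOptimizer` now forwards to `torch.autograd.backward(tensors=..., grad_tensors=..., inputs=..., retain_graph=...)`. Restricting `inputs=` is what lets the zero-bubble scheduler split one microbatch's backward into a B node (gradients w.r.t. the stage input only) and a deferred W node (parameter gradients). A minimal standalone sketch of that split — the single linear layer, shapes, and the `(x * x).mean()`-style loss mirror the test's toy model and are illustrative assumptions, not ColossalAI internals:

```python
import torch

# One pipeline "stage" reduced to a single linear layer.
stage = torch.nn.Linear(8, 8)
x = torch.rand(4, 8, requires_grad=True)  # activation arriving from the previous stage
out = stage(x)
loss = (out * out).mean()

# B step: accumulate gradients only into the stage input, keeping the graph
# alive so the parameter gradients can still be derived from it later.
torch.autograd.backward(loss, inputs=[x], retain_graph=True)
assert x.grad is not None          # ready to SEND_BACKWARD to the upstream stage
assert stage.weight.grad is None   # W is deliberately deferred

# W step: fill in the parameter gradients when the schedule reaches the W node.
torch.autograd.backward(loss, inputs=list(stage.parameters()))
assert stage.weight.grad is not None and stage.bias.grad is not None
```

Because the B step populates `.grad` only for the listed inputs, the stage can send its input gradient upstream immediately, and `retain_graph=True` keeps the autograd graph alive until the matching W node runs — the ordering the hand-written `ScheduledNode` lists at the top of the test encode explicitly.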