From 89ca62979c57890823008b6cb4dd1a492beada26 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 29 Dec 2023 16:27:01 +0800 Subject: [PATCH 1/6] fix: add fallback order option and update 1f1b --- .../booster/plugin/hybrid_parallel_plugin.py | 9 ----- colossalai/pipeline/p2p.py | 37 ++++++++++++++----- colossalai/pipeline/schedule/one_f_one_b.py | 20 +++++++--- 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ea74f75f43c8..c52de0ba7a2a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,5 +1,4 @@ import ctypes -import os import random from contextlib import contextmanager from functools import partial @@ -23,7 +22,6 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.logging import get_dist_logger from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -984,13 +982,6 @@ def __init__( self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - if os.getenv("NCCL_BUFFSIZE") is None: - logger = get_dist_logger() - logger.warning( - "Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen." 
- ) - os.environ["NCCL_BUFFSIZE"] = "134217728" - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert ( diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index cdb7a6a1e539..d32ff501f033 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -344,6 +344,7 @@ def _communicate( recv_group: Optional[ProcessGroup] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, + send_prior_fallback: Optional[bool] = None, ) -> Any: """ Send and receive object from send_dst and recv_src respectively @@ -368,8 +369,14 @@ def _communicate( # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata, # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case. if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None): - _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) - return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv) + assert send_prior_fallback is not None, "Priority must be set if fallback happens" + if send_prior_fallback: + _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) + return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv) + else: + recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv) + _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) + return recv_data # NOTE: only the following 5 cases are valid: # 1. 
send() [needs extra metadata] and no recv() @@ -437,7 +444,7 @@ def _communicate( raise ValueError("Unknown data type {}".format(metadata_recv.data_type)) -def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None: +def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None: """send anything to dst rank Args: @@ -447,10 +454,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta Returns: None """ - _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata) + _communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs) -def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any: +def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any: """recv anything from src Args: @@ -459,7 +466,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optiona Returns: Any: Object received from src. 
""" - return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv) + return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs) def _p2p_comm( @@ -557,7 +564,10 @@ def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[ prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() input_tensor = _recv_object( - prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv + prev_rank, + cur_rank, + self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), + metadata_recv=metadata_recv, ) return input_tensor @@ -575,7 +585,10 @@ def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() output_tensor_grad = _recv_object( - next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv + next_rank, + cur_rank, + self.stage_manager.get_p2p_process_group(next_rank, cur_rank), + metadata_recv=metadata_recv, ) return output_tensor_grad @@ -595,7 +608,7 @@ def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank), - send_metadata, + send_metadata=send_metadata, ) def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None: @@ -613,7 +626,7 @@ def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank), - send_metadata, + send_metadata=send_metadata, ) def send_forward_recv_backward( @@ -622,6 +635,7 @@ def send_forward_recv_backward( next_rank: Optional[int] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, + send_prior_fallback: Optional[bool] = None, ) -> Any: """Sends the gradient 
tensor to and copy the gradient tensor from the next stage in pipeline @@ -642,6 +656,7 @@ def send_forward_recv_backward( recv_group=group, send_metadata=send_metadata, metadata_recv=metadata_recv, + send_prior_fallback=send_prior_fallback, ) def send_backward_recv_forward( @@ -650,6 +665,7 @@ def send_backward_recv_forward( prev_rank: Optional[int] = None, send_metadata: bool = True, metadata_recv: Optional[P2PMetadata] = None, + send_prior_fallback: Optional[bool] = None, ) -> Any: """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline @@ -670,6 +686,7 @@ def send_backward_recv_forward( recv_group=group, send_metadata=send_metadata, metadata_recv=metadata_recv, + send_prior_fallback=send_prior_fallback, ) def p2p_communicate( diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 6b2436d545e7..0c3bedb318eb 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -94,7 +94,7 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.send_metadata_backward = True self.metadata_recv_forward = None self.metadata_recv_backward = None - + self.last_batch_size = self.batch_size def load_micro_batch(self) -> Any: @@ -166,7 +166,9 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) self.send_metadata_backward = not self.enable_metadata_cache - def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: + def send_forward_recv_backward( + self, output_object: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None + ) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. For 1F1B. 
@@ -180,6 +182,7 @@ def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) next_rank, send_metadata=self.send_metadata_forward, metadata_recv=self.metadata_recv_backward, + send_prior_fallback=send_prior_fallback, ) self.send_metadata_forward = not self.enable_metadata_cache if self.enable_metadata_cache and self.metadata_recv_backward is None: @@ -187,7 +190,9 @@ def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) return output_tensor_grad - def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: + def send_backward_recv_forward( + self, output_object: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None + ) -> Any: """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. For 1F1B. @@ -201,6 +206,7 @@ def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) prev_rank, send_metadata=self.send_metadata_backward, metadata_recv=self.metadata_recv_forward, + send_prior_fallback=send_prior_fallback, ) self.send_metadata_backward = not self.enable_metadata_cache if self.enable_metadata_cache and self.metadata_recv_forward is None: @@ -365,7 +371,9 @@ def run_forward_backward( last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - output_obj_grad = self.send_forward_recv_backward(output_obj) + output_obj_grad = self.send_forward_recv_backward( + output_obj, send_prior_fallback=self.stage_manager.stage % 2 == 0 + ) # Add input_obj and output_obj to end of list. 
input_objs.append(input_obj) output_objs.append(output_obj) @@ -379,7 +387,9 @@ def run_forward_backward( if last_iteration: self.send_backward(input_obj_grad) else: - input_obj = self.send_backward_recv_forward(input_obj_grad) + input_obj = self.send_backward_recv_forward( + input_obj_grad, send_prior_fallback=self.stage_manager.stage % 2 == 0 + ) # Run cooldown backward passes. for i in range(num_warmup_microbatches): From 2e6d5262592d1d8d1988507537715366d361680d Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Sat, 30 Dec 2023 03:28:31 +0800 Subject: [PATCH 2/6] fix: fix dead lock comm in interleaved pp --- .../pipeline/schedule/interleaved_pp.py | 211 +++++++++++++----- colossalai/pipeline/schedule/one_f_one_b.py | 4 + 2 files changed, 162 insertions(+), 53 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 23b3f4e6c60d..9134458ba990 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -108,7 +108,8 @@ def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: Returns: int: The model chunk idx of the input microbatch_id """ - microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) + assert microbatch_id < self.num_microbatch * self.num_model_chunks + microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages if not is_forward: model_chunk_id = self.num_model_chunks - model_chunk_id - 1 @@ -181,38 +182,92 @@ def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = self.send_metadata_backward = not self.enable_metadata_cache def send_forward_recv_backward( - self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None + self, + model_chunk_id_send: int, + model_chunk_id_recv: int, + output_object: Any, + next_rank: 
Optional[int] = None, + send_prior_fallback: Optional[bool] = None, ) -> Any: - with self.stage_manager.switch_model_chunk_id(model_chunk_id): - if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.send_forward_recv_backward( - output_object, - next_rank, - send_metadata=self.send_metadata_forward, - metadata_recv=self.metadata_recv_backward, - ) - self.send_metadata_forward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): + send_data = not self.stage_manager.is_last_stage() + with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): + recv_data = not self.stage_manager.is_last_stage() + + if send_data and recv_data: + output_tensor_grad = self.comm.send_forward_recv_backward( + output_object, + next_rank, + send_metadata=self.send_metadata_forward, + metadata_recv=self.metadata_recv_backward, + send_prior_fallback=send_prior_fallback, + ) + self.send_metadata_forward = not self.enable_metadata_cache + if self.enable_metadata_cache and self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + return output_tensor_grad - return output_tensor_grad + elif send_data: + return self.send_forward(model_chunk_id_send, output_object) + + elif recv_data: + return self.recv_backward(model_chunk_id_recv) def send_backward_recv_forward( - self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None + self, + model_chunk_id_send: int, + model_chunk_id_recv: int, + output_object: Any, + prev_rank: Optional[int] = None, + send_prior_fallback: Optional[bool] = None, ) -> Any: - with self.stage_manager.switch_model_chunk_id(model_chunk_id): - if not self.stage_manager.is_first_stage(): - input_tensor = self.comm.send_backward_recv_forward( - output_object, - prev_rank, - 
send_metadata=self.send_metadata_backward, - metadata_recv=self.metadata_recv_forward, - ) - self.send_metadata_backward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_forward is None: - self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + with self.stage_manager.switch_model_chunk_id(model_chunk_id_send): + send_data = not self.stage_manager.is_first_stage() + with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): + recv_data = not self.stage_manager.is_first_stage() + + if send_data and recv_data: + input_tensor = self.comm.send_backward_recv_forward( + output_object, + prev_rank, + send_metadata=self.send_metadata_backward, + metadata_recv=self.metadata_recv_forward, + send_prior_fallback=send_prior_fallback, + ) + self.send_metadata_backward = not self.enable_metadata_cache + if self.enable_metadata_cache and self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + return input_tensor + + elif send_data: + return self.send_backward(model_chunk_id_send, output_object) + + elif recv_data: + return self.recv_forward(model_chunk_id_recv) + + def send_forward_recv_forward( + self, model_chunk_id_send: int, model_chunk_id_recv: int, output_obj: Any, send_prior: bool + ): + if send_prior: + self.send_forward(model_chunk_id_send, output_obj) + input_obj = self.recv_forward(model_chunk_id_recv) + else: + input_obj = self.recv_forward(model_chunk_id_recv) + self.send_forward(model_chunk_id_send, output_obj) - return input_tensor + return input_obj + + def send_backward_recv_backward( + self, model_chunk_id_send: int, model_chunk_id_recv: int, output_obj: Any, send_prior: bool + ): + if send_prior: + self.send_backward(model_chunk_id_send, output_obj) + input_obj = self.recv_backward(model_chunk_id_recv) + else: + input_obj = self.recv_backward(model_chunk_id_recv) + self.send_backward(model_chunk_id_send, output_obj) + + return input_obj def 
forward_step( self, @@ -321,12 +376,23 @@ def run_forward_only( if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_current_device()) - # Run warmup forward passes. + model_chunk_id = self.get_model_chunk_id(0, is_forward=True) + input_obj = self.recv_forward(model_chunk_id) + for i in range(self.num_microbatch * self.num_model_chunks): + last_iteration = i == self.num_microbatch * self.num_model_chunks - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - self.send_forward(model_chunk_id, output_obj) + + if not last_iteration: + input_obj = self.send_forward_recv_forward( + model_chunk_id_send=model_chunk_id, + model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), + output_obj=output_obj, + send_prior=self.stage_manager.stage % 2 == 0, + ) + else: + self.send_forward(model_chunk_id, output_obj) if outputs is not None: outputs = merge_batch(outputs) @@ -364,54 +430,93 @@ def run_forward_backward( if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_current_device()) + model_chunk_id = self.get_model_chunk_id(0, is_forward=True) + input_obj = self.recv_forward(model_chunk_id) + # Run warmup forward passes. 
for i in range(num_warmup_microbatch): + last_iteration = i == num_warmup_microbatch - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) - self.send_forward(model_chunk_id, output_obj) + + if last_iteration and num_microbatch_remaining == 0: + self.send_forward(model_chunk_id, output_obj) + else: + input_obj = self.send_forward_recv_forward( + model_chunk_id_send=model_chunk_id, + model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), + output_obj=output_obj, + send_prior=self.stage_manager.stage % 2 == 0, + ) if num_microbatch_remaining > 0: - model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + model_chunk_id = self.get_model_chunk_id(0, is_forward=False) + output_obj_grad = self.recv_backward(model_chunk_id) # Run 1F1B in steady state. for i in range(num_microbatch_remaining): - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) last_iteration = i == num_microbatch_remaining - 1 + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - self.send_forward(model_chunk_id, output_obj) # Add input_obj and output_obj to end of list. input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - output_obj_grad = self.recv_backward(model_chunk_id) - - # Pop output_obj and output_obj from the start of the list for - # the backward pass. 
- input_obj = input_objs[model_chunk_id].pop(0) - output_obj = output_objs[model_chunk_id].pop(0) + # Pop output_obj and output_obj from the start of the list for the backward pass. + _input_obj = input_objs[model_chunk_id].pop(0) + _output_obj = output_objs[model_chunk_id].pop(0) + input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) + + # NOTE: perform 2x communication for forward and backward + if last_iteration and num_microbatch == num_microbatch_remaining: + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) + self.send_forward(model_chunk_id, output_obj) + else: + output_obj_grad = self.send_forward_recv_backward( + model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True), + model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), + output_object=output_obj, + send_prior_fallback=self.stage_manager.stage % 2 == 0, + ) - # backward - input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) - self.send_backward(model_chunk_id, input_obj_grad) + if last_iteration: + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) + self.send_backward(model_chunk_id, input_obj_grad) + else: + input_obj = self.send_backward_recv_forward( + model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), + model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True), + output_object=input_obj_grad, + send_prior_fallback=self.stage_manager.stage % 2 == 0, + ) - if not last_iteration: - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + if num_microbatch_remaining == 0: + model_chunk_id = self.get_model_chunk_id(0, is_forward=False) + output_obj_grad = self.recv_backward(model_chunk_id) # Run cooldown backward passes. 
for i in range(num_microbatch_remaining, num_microbatch): + last_iteration = i == num_microbatch - 1 model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - input_obj = input_objs[model_chunk_id].pop(0) - output_obj = output_objs[model_chunk_id].pop(0) - output_obj_grad = self.recv_backward(model_chunk_id) - input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) - self.send_backward(model_chunk_id, input_obj_grad) + _input_obj = input_objs[model_chunk_id].pop(0) + _output_obj = output_objs[model_chunk_id].pop(0) + # output_obj_grad = self.recv_backward(model_chunk_id) + input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) + + if not last_iteration: + self.send_backward_recv_backward( + model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), + model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), + output_obj=input_obj_grad, + send_prior=self.stage_manager.stage % 2 == 0, + ) + else: + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) + self.send_backward(model_chunk_id, input_obj_grad) assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 0c3bedb318eb..799a7cf4a65b 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -177,6 +177,8 @@ def send_forward_recv_backward( next_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_last_stage(): + if not self.send_metadata_forward and self.metadata_recv_backward is not None: + send_prior_fallback = None # must not fallback output_tensor_grad = self.comm.send_forward_recv_backward( output_object, next_rank, @@ -201,6 +203,8 @@ def send_backward_recv_forward( prev_rank (int, optional): The rank of the recipient of the tensor. 
""" if not self.stage_manager.is_first_stage(): + if not self.send_metadata_backward and self.metadata_recv_forward is not None: + send_prior_fallback = None # must not fallback input_tensor = self.comm.send_backward_recv_forward( output_object, prev_rank, From 6163b383a8f48c254e19526d10483605af42a578 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Sat, 30 Dec 2023 03:33:27 +0800 Subject: [PATCH 3/6] test: modify p2p test --- tests/test_pipeline/test_p2p_communication.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 40b6ac8eb6ff..1c859fd93bf8 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -1,5 +1,3 @@ -import warnings - import pytest import torch import torch.distributed as dist @@ -33,7 +31,7 @@ def check_p2p_communication(): for obj in data: p2p.send_forward(obj) for i in range(len(data)): - recv_obj = p2p.send_forward_recv_backward(data[i]) + recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False) assert recv_obj == data[-(i + 1)] elif rank == 1: for obj in data: @@ -48,7 +46,7 @@ def check_p2p_communication(): for obj in data: p2p.send_backward(obj) for i in range(len(data)): - recv_obj = p2p.send_backward_recv_forward(data[i]) + recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True) assert recv_obj == data[-(i + 1)] elif rank == 0: for obj in data: @@ -59,7 +57,6 @@ def check_p2p_communication(): p2p.send_forward(data[-(i + 1)]) assert recv_obj == data[i] - warnings.filterwarnings("error") tensor_metadata = TensorMetadata( key=None, shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad ) From ff152fc0aef8072afed37b43aed4609c6666f988 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Sat, 30 Dec 2023 05:07:53 +0800 Subject: [PATCH 4/6] style: polish code --- colossalai/pipeline/schedule/one_f_one_b.py | 72 
++++++++++----------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 799a7cf4a65b..be60dcc748ab 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -54,10 +54,10 @@ def __init__( # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache - self.send_metadata_forward = True - self.send_metadata_backward = True - self.metadata_recv_forward = None - self.metadata_recv_backward = None + self.send_tensor_metadata = True + self.send_grad_metadata = True + self.tensor_metadata_recv = None + self.grad_metadata_recv = None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -90,10 +90,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) # NOTE: disable metadata cache when batch size changes (not valid anymore) if self.batch_size != self.last_batch_size: self.enable_metadata_cache = False - self.send_metadata_forward = True - self.send_metadata_backward = True - self.metadata_recv_forward = None - self.metadata_recv_backward = None + self.send_tensor_metadata = True + self.send_grad_metadata = True + self.tensor_metadata_recv = None + self.grad_metadata_recv = None self.last_batch_size = self.batch_size @@ -119,9 +119,9 @@ def recv_forward(self, prev_rank: int = None) -> Any: Any: The input tensor or input tensor list. 
""" if not self.stage_manager.is_first_stage(): - input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) - if self.enable_metadata_cache and self.metadata_recv_forward is None: - self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + if self.enable_metadata_cache and self.tensor_metadata_recv is None: + self.tensor_metadata_recv = create_fast_send_metadata(input_tensor) return input_tensor @@ -136,13 +136,13 @@ def recv_backward(self, next_rank: int = None) -> Any: Any: The input gradient tensor or gradient tensor list. """ if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) - if self.enable_metadata_cache and self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + if self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad) return output_tensor_grad - def send_forward(self, output_object: Any, next_rank: int = None) -> None: + def send_forward(self, output_tensor: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. For 1F1B. @@ -151,10 +151,10 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None: next_rank (int, optional): The rank of the recipient of the tensor. 
""" if not self.stage_manager.is_last_stage(): - self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) - self.send_metadata_forward = not self.enable_metadata_cache + self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) + self.send_tensor_metadata = not self.enable_metadata_cache - def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For 1F1B. @@ -163,11 +163,11 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(): - self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) - self.send_metadata_backward = not self.enable_metadata_cache + self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) + self.send_grad_metadata = not self.enable_metadata_cache def send_forward_recv_backward( - self, output_object: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None + self, output_tensor: Any, next_rank: int = None, send_prior_fallback: Optional[bool] = None ) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. For 1F1B. @@ -177,23 +177,23 @@ def send_forward_recv_backward( next_rank (int, optional): The rank of the recipient of the tensor. 
""" if not self.stage_manager.is_last_stage(): - if not self.send_metadata_forward and self.metadata_recv_backward is not None: + if not self.send_tensor_metadata and self.grad_metadata_recv is not None: send_prior_fallback = None # must not fallback output_tensor_grad = self.comm.send_forward_recv_backward( - output_object, + output_tensor, next_rank, - send_metadata=self.send_metadata_forward, - metadata_recv=self.metadata_recv_backward, + send_metadata=self.send_tensor_metadata, + metadata_recv=self.grad_metadata_recv, send_prior_fallback=send_prior_fallback, ) - self.send_metadata_forward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + self.send_tensor_metadata = not self.enable_metadata_cache + if self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad) return output_tensor_grad def send_backward_recv_forward( - self, output_object: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None + self, input_tensor_grad: Any, prev_rank: int = None, send_prior_fallback: Optional[bool] = None ) -> Any: """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. For 1F1B. @@ -203,18 +203,18 @@ def send_backward_recv_forward( prev_rank (int, optional): The rank of the recipient of the tensor. 
""" if not self.stage_manager.is_first_stage(): - if not self.send_metadata_backward and self.metadata_recv_forward is not None: + if not self.send_grad_metadata and self.tensor_metadata_recv is not None: send_prior_fallback = None # must not fallback input_tensor = self.comm.send_backward_recv_forward( - output_object, + input_tensor_grad, prev_rank, - send_metadata=self.send_metadata_backward, - metadata_recv=self.metadata_recv_forward, + send_metadata=self.send_grad_metadata, + metadata_recv=self.tensor_metadata_recv, send_prior_fallback=send_prior_fallback, ) - self.send_metadata_backward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_forward is None: - self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + self.send_grad_metadata = not self.enable_metadata_cache + if self.enable_metadata_cache and self.tensor_metadata_recv is None: + self.tensor_metadata_recv = create_fast_send_metadata(input_tensor) return input_tensor From ba48d2acc0ead1b09e23b7434b1696d8f57e5812 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Sat, 30 Dec 2023 13:21:26 +0800 Subject: [PATCH 5/6] fix: fix interleaved pp comm --- .../pipeline/schedule/interleaved_pp.py | 171 ++++++++++-------- 1 file changed, 91 insertions(+), 80 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 9134458ba990..4c65470fab19 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -41,10 +41,10 @@ def __init__( # P2PMeta cache self.enable_metadata_cache = enable_metadata_cache - self.send_metadata_forward = True - self.send_metadata_backward = True - self.metadata_recv_forward = None - self.metadata_recv_backward = None + self.send_tensor_metadata = True + self.send_grad_metadata = True + self.tensor_metadata_recv = None + self.grad_metadata_recv = None def load_batch(self, data_iter: Iterable, device: 
Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -77,10 +77,10 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) # NOTE: disable metadata cache when batch size changes (not valid anymore) if self.batch_size != self.last_batch_size: self.enable_metadata_cache = False - self.send_metadata_forward = True - self.send_metadata_backward = True - self.metadata_recv_forward = None - self.metadata_recv_backward = None + self.send_tensor_metadata = True + self.send_grad_metadata = True + self.tensor_metadata_recv = None + self.grad_metadata_recv = None self.last_batch_size = self.batch_size @@ -128,9 +128,9 @@ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): - input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) - if self.enable_metadata_cache and self.metadata_recv_forward is None: - self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.tensor_metadata_recv) + if self.enable_metadata_cache and self.tensor_metadata_recv is None: + self.tensor_metadata_recv = create_fast_send_metadata(input_tensor) return input_tensor @@ -147,13 +147,13 @@ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) - if self.enable_metadata_cache and self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + if self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv 
= create_fast_send_metadata(output_tensor_grad) return output_tensor_grad - def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None: + def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. For interleaved 1F1B. @@ -164,10 +164,10 @@ def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): - self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) - self.send_metadata_forward = not self.enable_metadata_cache + self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) + self.send_tensor_metadata = not self.enable_metadata_cache - def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None: + def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For interleaved 1F1B. 
@@ -178,14 +178,14 @@ def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): - self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) - self.send_metadata_backward = not self.enable_metadata_cache + self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) + self.send_grad_metadata = not self.enable_metadata_cache def send_forward_recv_backward( self, model_chunk_id_send: int, model_chunk_id_recv: int, - output_object: Any, + output_tensor: Any, next_rank: Optional[int] = None, send_prior_fallback: Optional[bool] = None, ) -> Any: @@ -195,29 +195,29 @@ def send_forward_recv_backward( recv_data = not self.stage_manager.is_last_stage() if send_data and recv_data: + if not self.send_tensor_metadata and self.grad_metadata_recv is not None: + send_prior_fallback = None # must not fallback output_tensor_grad = self.comm.send_forward_recv_backward( - output_object, + output_tensor, next_rank, - send_metadata=self.send_metadata_forward, - metadata_recv=self.metadata_recv_backward, + send_metadata=self.send_tensor_metadata, + metadata_recv=self.grad_metadata_recv, send_prior_fallback=send_prior_fallback, ) - self.send_metadata_forward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + self.send_tensor_metadata = not self.enable_metadata_cache + if self.enable_metadata_cache and self.grad_metadata_recv is None: + self.grad_metadata_recv = create_fast_send_metadata(output_tensor_grad) return output_tensor_grad - elif send_data: - return self.send_forward(model_chunk_id_send, output_object) - - elif recv_data: - return self.recv_backward(model_chunk_id_recv) + # send only or recv only + self.send_forward(model_chunk_id_send, 
output_tensor) + return self.recv_backward(model_chunk_id_recv) def send_backward_recv_forward( self, model_chunk_id_send: int, model_chunk_id_recv: int, - output_object: Any, + input_tensor_grad: Any, prev_rank: Optional[int] = None, send_prior_fallback: Optional[bool] = None, ) -> Any: @@ -227,47 +227,47 @@ def send_backward_recv_forward( recv_data = not self.stage_manager.is_first_stage() if send_data and recv_data: + if not self.send_grad_metadata and self.tensor_metadata_recv is not None: + send_prior_fallback = None # must not fallback input_tensor = self.comm.send_backward_recv_forward( - output_object, + input_tensor_grad, prev_rank, - send_metadata=self.send_metadata_backward, - metadata_recv=self.metadata_recv_forward, + send_metadata=self.send_grad_metadata, + metadata_recv=self.tensor_metadata_recv, send_prior_fallback=send_prior_fallback, ) - self.send_metadata_backward = not self.enable_metadata_cache - if self.enable_metadata_cache and self.metadata_recv_forward is None: - self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + self.send_grad_metadata = not self.enable_metadata_cache + if self.enable_metadata_cache and self.tensor_metadata_recv is None: + self.tensor_metadata_recv = create_fast_send_metadata(input_tensor) return input_tensor - elif send_data: - return self.send_backward(model_chunk_id_send, output_object) - - elif recv_data: - return self.recv_forward(model_chunk_id_recv) + # send only or recv only + self.send_backward(model_chunk_id_send, input_tensor_grad) + return self.recv_forward(model_chunk_id_recv) def send_forward_recv_forward( - self, model_chunk_id_send: int, model_chunk_id_recv: int, output_obj: Any, send_prior: bool + self, model_chunk_id_send: int, model_chunk_id_recv: int, output_tensor: Any, send_prior: bool ): if send_prior: - self.send_forward(model_chunk_id_send, output_obj) - input_obj = self.recv_forward(model_chunk_id_recv) + self.send_forward(model_chunk_id_send, output_tensor) + 
input_tensor = self.recv_forward(model_chunk_id_recv) else: - input_obj = self.recv_forward(model_chunk_id_recv) - self.send_forward(model_chunk_id_send, output_obj) + input_tensor = self.recv_forward(model_chunk_id_recv) + self.send_forward(model_chunk_id_send, output_tensor) - return input_obj + return input_tensor def send_backward_recv_backward( - self, model_chunk_id_send: int, model_chunk_id_recv: int, output_obj: Any, send_prior: bool + self, model_chunk_id_send: int, model_chunk_id_recv: int, input_tensor_grad: Any, send_prior: bool ): if send_prior: - self.send_backward(model_chunk_id_send, output_obj) - input_obj = self.recv_backward(model_chunk_id_recv) + self.send_backward(model_chunk_id_send, input_tensor_grad) + output_tensor_grad = self.recv_backward(model_chunk_id_recv) else: - input_obj = self.recv_backward(model_chunk_id_recv) - self.send_backward(model_chunk_id_send, output_obj) + output_tensor_grad = self.recv_backward(model_chunk_id_recv) + self.send_backward(model_chunk_id_send, input_tensor_grad) - return input_obj + return output_tensor_grad def forward_step( self, @@ -388,7 +388,7 @@ def run_forward_only( input_obj = self.send_forward_recv_forward( model_chunk_id_send=model_chunk_id, model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), - output_obj=output_obj, + output_tensor=output_obj, send_prior=self.stage_manager.stage % 2 == 0, ) else: @@ -447,7 +447,7 @@ def run_forward_backward( input_obj = self.send_forward_recv_forward( model_chunk_id_send=model_chunk_id, model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=True), - output_obj=output_obj, + output_tensor=output_obj, send_prior=self.stage_manager.stage % 2 == 0, ) @@ -472,27 +472,38 @@ def run_forward_backward( input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) # NOTE: perform 2x communication for forward and backward - if last_iteration and num_microbatch == num_microbatch_remaining: - model_chunk_id = 
self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) - self.send_forward(model_chunk_id, output_obj) + def send_forward_recv_backward(): + if last_iteration and num_microbatch == num_microbatch_remaining: + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) + self.send_forward(model_chunk_id, output_obj) + else: + output_obj_grad = self.send_forward_recv_backward( + model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True), + model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), + output_tensor=output_obj, + send_prior_fallback=self.stage_manager.stage % 2 == 0, + ) + return output_obj_grad + + def send_backward_recv_forward(): + if last_iteration: + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) + self.send_backward(model_chunk_id, input_obj_grad) + else: + input_obj = self.send_backward_recv_forward( + model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), + model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True), + input_tensor_grad=input_obj_grad, + send_prior_fallback=self.stage_manager.stage % 2 == 0 and i > 0, + ) + return input_obj + + if self.stage_manager.stage % 2 == 0: + output_obj_grad = send_forward_recv_backward() + input_obj = send_backward_recv_forward() else: - output_obj_grad = self.send_forward_recv_backward( - model_chunk_id_send=self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True), - model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), - output_object=output_obj, - send_prior_fallback=self.stage_manager.stage % 2 == 0, - ) - - if last_iteration: - model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - self.send_backward(model_chunk_id, input_obj_grad) - else: - input_obj = self.send_backward_recv_forward( - model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), - model_chunk_id_recv=self.get_model_chunk_id(i + num_warmup_microbatch + 1, 
is_forward=True), - output_object=input_obj_grad, - send_prior_fallback=self.stage_manager.stage % 2 == 0, - ) + input_obj = send_backward_recv_forward() + output_obj_grad = send_forward_recv_backward() if num_microbatch_remaining == 0: model_chunk_id = self.get_model_chunk_id(0, is_forward=False) @@ -508,10 +519,10 @@ def run_forward_backward( input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) if not last_iteration: - self.send_backward_recv_backward( + output_obj_grad = self.send_backward_recv_backward( model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), - output_obj=input_obj_grad, + input_tensor_grad=input_obj_grad, send_prior=self.stage_manager.stage % 2 == 0, ) else: From 50c3ccdc8b9cf979c88c43a2146a9ae2c2cde046 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Sat, 30 Dec 2023 13:50:41 +0800 Subject: [PATCH 6/6] fix: fix fallback order in cooldown phase --- colossalai/pipeline/schedule/interleaved_pp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 4c65470fab19..aa18a85204c2 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -432,7 +432,6 @@ def run_forward_backward( model_chunk_id = self.get_model_chunk_id(0, is_forward=True) input_obj = self.recv_forward(model_chunk_id) - # Run warmup forward passes. for i in range(num_warmup_microbatch): last_iteration = i == num_warmup_microbatch - 1 @@ -508,7 +507,6 @@ def send_backward_recv_forward(): if num_microbatch_remaining == 0: model_chunk_id = self.get_model_chunk_id(0, is_forward=False) output_obj_grad = self.recv_backward(model_chunk_id) - # Run cooldown backward passes. 
for i in range(num_microbatch_remaining, num_microbatch): last_iteration = i == num_microbatch - 1 @@ -523,7 +521,7 @@ def send_backward_recv_forward(): model_chunk_id_send=self.get_model_chunk_id(i, is_forward=False), model_chunk_id_recv=self.get_model_chunk_id(i + 1, is_forward=False), input_tensor_grad=input_obj_grad, - send_prior=self.stage_manager.stage % 2 == 0, + send_prior=self.stage_manager.stage % 2 == 0 and i > num_microbatch_remaining, ) else: model_chunk_id = self.get_model_chunk_id(i, is_forward=False)