Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ctypes
import os
import random
from contextlib import contextmanager
from functools import partial
Expand All @@ -23,7 +22,6 @@
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
Expand Down Expand Up @@ -984,13 +982,6 @@ def __init__(
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
if os.getenv("NCCL_BUFFSIZE") is None:
logger = get_dist_logger()
logger.warning(
"Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
)
os.environ["NCCL_BUFFSIZE"] = "134217728"

assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
Expand Down
37 changes: 27 additions & 10 deletions colossalai/pipeline/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def _communicate(
recv_group: Optional[ProcessGroup] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""
Send and receive object from send_dst and recv_src respectively
Expand All @@ -368,8 +369,14 @@ def _communicate(
# NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata,
# we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case.
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None):
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
assert send_prior_fallback is not None, "Priority must be set if fallback happens"
if send_prior_fallback:
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
else:
recv_data = _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv)
_communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata)
return recv_data

# NOTE: only the following 5 cases are valid:
# 1. send() [needs extra metadata] and no recv()
Expand Down Expand Up @@ -437,7 +444,7 @@ def _communicate(
raise ValueError("Unknown data type {}".format(metadata_recv.data_type))


def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None:
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, **kwargs) -> None:
"""send anything to dst rank

Args:
Expand All @@ -447,10 +454,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta
Returns:
None
"""
_communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata)
_communicate(object, send_dst=dst, recv_src=None, send_group=group, **kwargs)


def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any:
def _recv_object(src: int, dst: int, group: ProcessGroup, **kwargs) -> Any:
"""recv anything from src

Args:
Expand All @@ -459,7 +466,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optiona
Returns:
Any: Object received from src.
"""
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv)
return _communicate(None, send_dst=None, recv_src=src, recv_group=group, **kwargs)


def _p2p_comm(
Expand Down Expand Up @@ -557,7 +564,10 @@ def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(
prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv
prev_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank),
metadata_recv=metadata_recv,
)

return input_tensor
Expand All @@ -575,7 +585,10 @@ def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(
next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv
next_rank,
cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank),
metadata_recv=metadata_recv,
)

return output_tensor_grad
Expand All @@ -595,7 +608,7 @@ def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send
cur_rank,
next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank),
send_metadata,
send_metadata=send_metadata,
)

def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None:
Expand All @@ -613,7 +626,7 @@ def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send
cur_rank,
prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank),
send_metadata,
send_metadata=send_metadata,
)

def send_forward_recv_backward(
Expand All @@ -622,6 +635,7 @@ def send_forward_recv_backward(
next_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline

Expand All @@ -642,6 +656,7 @@ def send_forward_recv_backward(
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
)

def send_backward_recv_forward(
Expand All @@ -650,6 +665,7 @@ def send_backward_recv_forward(
prev_rank: Optional[int] = None,
send_metadata: bool = True,
metadata_recv: Optional[P2PMetadata] = None,
send_prior_fallback: Optional[bool] = None,
) -> Any:
"""Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline

Expand All @@ -670,6 +686,7 @@ def send_backward_recv_forward(
recv_group=group,
send_metadata=send_metadata,
metadata_recv=metadata_recv,
send_prior_fallback=send_prior_fallback,
)

def p2p_communicate(
Expand Down
Loading