Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions colossalai/cluster/process_group_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,23 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
return np.unravel_index(rank, shape)

@staticmethod
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
"""Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
with wrap, index out of range would be wrapped around.
For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2)

Args:
coords (Tuple[int, ...]): Coordinate to be converted.
shape (Tuple[int, ...]): Shape of the process group mesh.
mode (Optional[str]): The mode for numpy.ravel_multi_index.

Returns:
int: Rank of the coordinate.
"""
return np.ravel_multi_index(coord, shape)

assert mode in ["raise", "wrap", "clip"]
return np.ravel_multi_index(coord, shape, mode)

def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
Expand Down
45 changes: 17 additions & 28 deletions colossalai/pipeline/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,10 @@ def recv_forward(self, prev_rank: int = None) -> Any:
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))

return input_tensor

Expand All @@ -193,14 +189,11 @@ def recv_backward(self, next_rank: int = None) -> Any:
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(next_rank, cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(next_rank, cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank))

return output_tensor_grad

Expand All @@ -211,12 +204,10 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(output_object, cur_rank, next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))

def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
Expand All @@ -225,9 +216,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
Loading