Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion colossalai/inference/pipeline/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
verbose: bool = False,
# TODO: implement early_stopping, and various gerneration options
early_stopping: bool = False,
do_sample: bool = False,
Expand All @@ -71,7 +72,7 @@ def __init__(
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
micro_batch_buffer_size or pp_size)
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager)
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
self.model = pp_model or self._shardformer(model, model_policy)

def inference(self, input_list):
Expand Down
7 changes: 6 additions & 1 deletion colossalai/inference/pipeline/microbatch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Status(Enum):
PREFILL = 1
GENERATE = 2
DONE = 3
COOLDOWN = 4


class MicroBatchDescription():
Expand Down Expand Up @@ -52,6 +53,8 @@ def state(self):
# TODO: add the condition for early stopping
if self.cur_length == self.target_length:
return Status.DONE
elif self.cur_length == self.target_length - 1:
return Status.COOLDOWN
else:
return Status.GENERATE

Expand Down Expand Up @@ -184,7 +187,9 @@ def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, ne
return self.cur_state

def export_new_tokens(self):
    """Collect the newly generated tokens of every buffered micro-batch.

    Returns:
        list: the per-sequence token lists from each micro-batch descrption
        in ``self.mb_descrption_buffer``, flattened into one list.
    """
    return [
        sequence_tokens
        for descrption in self.mb_descrption_buffer.values()
        for sequence_tokens in descrption.new_tokens.tolist()
    ]

def is_micro_batch_done(self):
Expand Down
89 changes: 89 additions & 0 deletions colossalai/pipeline/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,81 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
return object_list[0]


def _p2p_comm_shape(
    tensor_send_next: torch.Tensor,
    recv_prev: bool,
    peer: int,
    group: ProcessGroup,
):
    """Exchange tensor *shapes* with ``peer`` ahead of the actual tensor transfer.

    Sends the shape of ``tensor_send_next`` (when given) and/or receives the
    shape of the tensor the peer is about to send, so the receiver can
    pre-allocate a correctly sized buffer before the real transfer.

    Args:
        tensor_send_next (torch.Tensor): tensor whose shape is sent to ``peer``;
            pass ``None`` to skip sending.
        recv_prev (bool): whether to receive a shape from ``peer``.
        peer (int): rank to communicate with.
        group (ProcessGroup): process group used for the P2P ops.

    Returns:
        list | None: the received shape as a Python list, or ``None`` when
        ``recv_prev`` is ``False``.
    """
    send_next_shape = None
    recv_prev_shape = None

    if tensor_send_next is not None:
        # The shape is shipped as an int64 CUDA tensor so it can go through isend/irecv.
        send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
    if recv_prev:
        # NOTE(review): the receive buffer is hard-coded to 3 elements, so this
        # assumes every transferred tensor is exactly 3-dimensional — confirm
        # against the callers' tensor shapes.
        recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)

    ops = []
    if send_next_shape is not None:
        send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group)
        ops.append(send_next_op)
    if recv_prev_shape is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            recv_prev_shape,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)

    if len(ops) > 0:
        # Launch send and recv together, then block until both complete.
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    if recv_prev_shape is not None:
        # Convert to a plain list so it can be passed to torch.empty later.
        recv_prev_shape = recv_prev_shape.tolist()

    return recv_prev_shape


def _p2p_comm(
    tensor_send_next: torch.Tensor,
    recv_pre: bool,
    peer: int,
    group: ProcessGroup,
    comm_type: torch.dtype = torch.float32,
):
    """Fused P2P exchange: send a tensor to ``peer`` and/or receive one from it.

    First exchanges tensor shapes via ``_p2p_comm_shape`` so the receive buffer
    can be allocated with the right size, then performs the actual tensor
    send/recv in a single batched call.

    Args:
        tensor_send_next (torch.Tensor): tensor to send to ``peer``; pass
            ``None`` to skip sending.
        recv_pre (bool): whether to also receive a tensor from ``peer``.
        peer (int): rank to communicate with.
        group (ProcessGroup): process group used for the P2P ops.
        comm_type (torch.dtype): dtype of the receive buffer — presumably must
            match the dtype the sending rank uses; verify against callers.
            Defaults to ``torch.float32``.

    Returns:
        torch.Tensor | None: the received tensor, or ``None`` when ``recv_pre``
        is ``False``.
    """
    tensor_recv_prev = None
    recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group)
    if recv_pre:
        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type)

    ops = []
    if tensor_send_next is not None:
        send_next_op = dist.P2POp(
            dist.isend,
            tensor_send_next,
            peer=peer,
            group=group,
        )
        ops.append(send_next_op)

    if tensor_recv_prev is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            tensor_recv_prev,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)
    if len(ops) > 0:
        # NOTE(review): the send op is enqueued before the recv op; with
        # batch_isend_irecv the op ordering must pair up across ranks to avoid
        # deadlock — keep this order consistent with the peer's.
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    return tensor_recv_prev


class PipelineP2PCommunication:

def __init__(self, stage_manager: PipelineStageManager) -> None:
Expand Down Expand Up @@ -220,3 +295,17 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> Any:
    """Send ``output_object`` to a peer stage and optionally receive a tensor back.

    A fused send/recv built on ``torch.distributed.P2POp``: the tensor shape is
    exchanged first, then the tensor itself (see ``_p2p_comm``).

    Args:
        output_object (Any): tensor to send to ``peer``; ``None`` to skip sending.
        recv_pre (bool): whether to also receive a tensor from ``peer``.
        peer (int, optional): rank of the communication partner; defaults to the
            next stage's rank from the stage manager.

    Returns:
        The received tensor, or ``None`` when ``recv_pre`` is ``False``.
    """
    if peer is None:
        peer = self.stage_manager.get_next_rank()
    cur_rank = self.stage_manager.get_rank()
    recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer))
    return recv_tensor
Loading