diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 9236ee0a7bff..39366d2d69da 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -60,6 +60,7 @@ def __init__( new_length: int = 32, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, + verbose: bool = False, # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, do_sample: bool = False, @@ -71,7 +72,7 @@ def __init__( self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size) - self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager) + self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) self.model = pp_model or self._shardformer(model, model_policy) def inference(self, input_list): diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 7f4b14c17748..b6b008442cfd 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -10,6 +10,7 @@ class Status(Enum): PREFILL = 1 GENERATE = 2 DONE = 3 + COOLDOWN = 4 class MicroBatchDescription(): @@ -52,6 +53,8 @@ def state(self): # TODO: add the condition for early stopping if self.cur_length == self.target_length: return Status.DONE + elif self.cur_length == self.target_length - 1: + return Status.COOLDOWN else: return Status.GENERATE @@ -184,7 +187,9 @@ def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, ne return self.cur_state def export_new_tokens(self): - new_tokens_list = [i.new_tokens[0].tolist() for i in self.mb_descrption_buffer.values()] + new_tokens_list = [] + for i in self.mb_descrption_buffer.values(): + new_tokens_list.extend(i.new_tokens.tolist()) return new_tokens_list def 
is_micro_batch_done(self): diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index aed85cf91512..227ad2daca0e 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -159,6 +159,81 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: return object_list[0] +def _p2p_comm_shape( + tensor_send_next: torch.Tensor, + recv_prev: bool, + peer: int, + group: ProcessGroup, +): + send_next_shape = None + recv_prev_shape = None + + if tensor_send_next is not None: + send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64) + if recv_prev: + recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + + ops = [] + if send_next_shape is not None: + send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group) + ops.append(send_next_op) + if recv_prev_shape is not None: + recv_prev_op = dist.P2POp( + dist.irecv, + recv_prev_shape, + peer=peer, + group=group, + ) + ops.append(recv_prev_op) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + if recv_prev_shape is not None: + recv_prev_shape = recv_prev_shape.tolist() + + return recv_prev_shape + + +def _p2p_comm( + tensor_send_next: torch.Tensor, + recv_pre: bool, + peer: int, + group: ProcessGroup, + comm_type: torch.dtype = torch.float32, +): + tensor_recv_prev = None + recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group) + if recv_pre: + tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type) + + ops = [] + if tensor_send_next is not None: + send_next_op = dist.P2POp( + dist.isend, + tensor_send_next, + peer=peer, + group=group, + ) + ops.append(send_next_op) + + if tensor_recv_prev is not None: + recv_prev_op = dist.P2POp( + dist.irecv, + tensor_recv_prev, + peer=peer, + group=group, + ) + ops.append(recv_prev_op) + if len(ops) > 0: + reqs = 
dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + return tensor_recv_prev + + class PipelineP2PCommunication: def __init__(self, stage_manager: PipelineStageManager) -> None: @@ -220,3 +295,17 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + + def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> None: + """ + Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. + + Args: + output_object (Any): Object to be sent. + peer (int, optional): The rank of the recipient of the tensor. + """ + if peer is None: + peer = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer)) + return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index a6ca76a83468..85feebea6e6e 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -1,5 +1,6 @@ +import time from functools import partial -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Iterable, List, Optional, Union import torch import torch.cuda @@ -15,6 +16,17 @@ from .base import PipelineSchedule +class ActionIntervalBuffer(): + + def __init__(self): + self.hidden_states = None + self.new_token = None + + def clear(self): + self.hidden_states = None + self.new_token = None + + class GenerateSchedule(PipelineSchedule): ''' GenerateSchedule is a class that handles the pipeline parallel inference. @@ -22,11 +34,12 @@ class GenerateSchedule(PipelineSchedule): this schedule, the out for each encoding progress is on rank0. 
Args: - stage_manager (PipelineStageManager): Pipeline stage manager. - mb_manager (MicroBatchManager): Micro batch manager. + stage_manager (`PipelineStageManager`): Pipeline stage manager. + mb_manager (`MicroBatchManager`): Micro batch manager. + verbose (bool): Whether to verbose the information of the pipeline. ''' - def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager) -> None: + def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None: super().__init__(stage_manager) self.comm = PipelineP2PCommunication(stage_manager) self.mb_manager = mb_manager @@ -35,6 +48,8 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None self.num_microbatches: Optional[int] = None + self.verbose = verbose + self.action_interval_buffer = ActionIntervalBuffer() def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. 
@@ -95,9 +110,164 @@ def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1) return input_ids - @torch.no_grad() + def _recv_pre_stage(self) -> Any: + ''' + Receive the output from previous stage + + Returns: + Any: The output from previous stage + ''' + if self.stage_manager.num_stages == 2: + return self.comm.p2p_recv() + return self.comm.recv_forward() + + def LoadStageAction(self, model: Module) -> None: + """ + In this action, 1.load micro_batch 2.do the forward 3.step to update + """ + inputs_dict = self.load_micro_batch() + output_dict = model_forward(model, inputs_dict, None) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + + def GenTokenAction(self, model: Module): + """ + In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update + """ + hidden_states = self.action_interval_buffer.hidden_states + assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" + hidden_states = {'hidden_states': hidden_states} + logits = model_forward(model, None, hidden_states) + assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits['logits']) + + self.mb_manager.step(None, None, new_token) + self.action_interval_buffer.new_token = new_token + self.action_interval_buffer.hidden_states = None + + def HeadEncodingAction(self, model: Module): + """ + In this action, 1.prepare inputs for encoding for first stage. 
2.do the forward to get hidden states 3.step to update + """ + new_token = self.action_interval_buffer.new_token + assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None" + inputs_dict = self._prepare_inputs_for_new_token(new_token) + output_dict = model_forward(model, inputs_dict, None) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + + def BodyEncodingAction(self, model: Module): + hidden_states = self.action_interval_buffer.hidden_states + assert hidden_states is not None, "When not first stage, the hidden states should not be None" + inputs_dict = self._prepare_inputs_for_interval_stage() + hidden_states = {'hidden_states': hidden_states} + output_dict = model_forward(model, inputs_dict, hidden_states) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + + def CommAction(self, recv_pre: bool) -> torch.Tensor: + """ + In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage + """ + hidden_states = self.action_interval_buffer.hidden_states + ret = self.comm.p2p_communicate(hidden_states, recv_pre) + + self.action_interval_buffer.hidden_states = ret + + def genAction(self, model: Module): + """ + In p2p step method, we use `P2POp` asynchronous communication method, so the communication need to be done + at the begin of each microbatch, it's a more clear way to use an action list to do so. In this function, it will + generate a sequence action for current state, and do the action one by one. + + Args: + model (Module): Model to be run. + + Returns: + List[Callable]: A list of action, each action is a callable function, and it will be called in order. 
+ """ + actions = [] + if self.stage_manager.is_first_stage(): + if self.mb_manager.cur_state is Status.PREFILL: + actions.append(partial(self.CommAction, False)) + actions.append(partial(self.LoadStageAction, model)) + elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE: + actions.append(partial(self.CommAction, True)) + actions.append(partial(self.GenTokenAction, model)) + actions.append(partial(self.HeadEncodingAction, model)) + elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN: + actions.append(partial(self.CommAction, True)) + actions.append(partial(self.GenTokenAction, model)) + # other stage + else: + actions.append(partial(self.CommAction, True)) + actions.append(partial(self.BodyEncodingAction, model)) + + return actions + + def verbose_info(self, timestamps: List): + prefill = [] + encoder = [] + end2end = [] + for timestamp in timestamps: + prefill.append(timestamp[1] - timestamp[0]) + encoder.append( + sum(timestamp[i + 1] - timestamp[i] for i in range(1, + len(timestamp) - 1)) / (len(timestamp) - 2)) + end2end.append(timestamp[-1] - timestamp[0]) + print(f"Average prefill time: {sum(prefill)/len(prefill)}") + print(f"Average encode time: {sum(encoder)/len(encoder)}") + print(f"Average end2end time: {sum(end2end)/len(end2end)}") + def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: - """Forward one step of the pipeline + if self.stage_manager.num_stages == 2: + return self.generate_step_p2p(model, data_iter) + else: + return self.generate_step_broadcast(model, data_iter) + + @torch.no_grad() + def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """ + Forward one step of the pipeline, when pipeline size is 2, the schedule is a circle, broadcast communication will be + blocked, so we use `P2POp` asynchronous communication method. + + Args: + model (Module): Model to be run. 
+ data_iter (Iterable): Data iterator. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + output_sequence = [] + self.load_batch(data_iter) + model.eval() + + whole_timestamp = [] + + #run by round + for _ in range(self.round): + self.action_interval_buffer.clear() + while self.mb_manager.is_micro_batch_done() is False: + actions = self.genAction(model) + for action in actions: + action() + self.mb_manager.next() + # All microbatch in current round is DONE + if self.stage_manager.is_first_stage(): + output_sequence.extend(self.mb_manager.export_new_tokens()) + else: + self.CommAction(False) + self.mb_manager.clear() + + return output_sequence + + @torch.no_grad() + def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """ + Forward one step of the pipeline Args: model (Module): Model to be run. @@ -110,8 +280,11 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso self.load_batch(data_iter) model.eval() + whole_timestamp = [] # run by round for _ in range(self.round): + timestampes = [[] for _ in range(self.stage_manager.num_stages) + ] if self.verbose and self.stage_manager.is_first_stage() else None while self.mb_manager.is_micro_batch_done() is False: inputs_dict = None new_token = None @@ -120,6 +293,8 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso # First stage and in PREFILL phase, just load the inputs if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: inputs_dict = self.load_micro_batch() + if self.verbose and self.stage_manager.is_first_stage(): + timestampes[self.mb_manager.idx].append(time.time()) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) # In GENERATE phase @@ -130,7 +305,9 @@ def generate_step(self, model: Module, data_iter: 
Iterable) -> Union[torch.Tenso # First just generate a new token assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" logits = model_forward(model, None, hidden_states) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {output.keys()}" + if self.verbose and self.stage_manager.is_first_stage(): + timestampes[self.mb_manager.idx].append(time.time()) + assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits['logits']) self.mb_manager.step(None, None, new_token) # If the current micro batch is not DONE, go through blocks @@ -154,4 +331,10 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso if self.stage_manager.is_first_stage(): output_sequence.extend(self.mb_manager.export_new_tokens()) self.mb_manager.clear() + if self.verbose and self.stage_manager.is_first_stage(): + whole_timestamp.extend(timestampes) + + if self.verbose and self.stage_manager.is_first_stage(): + self.verbose_info(whole_timestamp) + return output_sequence