diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
index 0c748d725d5d..9236ee0a7bff 100644
--- a/colossalai/inference/pipeline/engine.py
+++ b/colossalai/inference/pipeline/engine.py
@@ -69,7 +69,8 @@ def __init__(
         self.pp_size = pp_size
         self.pg_mesh = ProcessGroupMesh(pp_size)
         self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
-        self.mb_manager = MicroBatchManager(new_length, micro_batch_size, micro_batch_buffer_size or pp_size)
+        self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
+                                            micro_batch_buffer_size or pp_size)
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager)
 
         self.model = pp_model or self._shardformer(model, model_policy)
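With this change the micro batch manager becomes stage-aware: the engine passes `stage_manager.stage` through so that stage 0 builds head descriptions while every other stage builds body descriptions (see the manager changes below). A minimal construction sketch; the concrete values here are illustrative, not from the PR:

```python
from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager

# Hypothetical values for illustration; in the engine these come from
# PPInferEngine.__init__ and the pipeline stage manager.
pp_size = 2
mb_manager = MicroBatchManager(
    stage=0,                          # stage_manager.stage: 0 on the head stage
    new_length=8,                     # number of tokens to generate per sequence
    micro_batch_size=1,
    micro_batch_buffer_size=pp_size,  # falls back to pp_size when not given
)
```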
diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py
index f54396bb3747..7f4b14c17748 100644
--- a/colossalai/inference/pipeline/microbatch_manager.py
+++ b/colossalai/inference/pipeline/microbatch_manager.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict
+from typing import Dict, Tuple
 
 import torch
 
@@ -13,53 +13,134 @@ class Status(Enum):
 
 
 class MicroBatchDescription():
+    """
+    This class records the information of each micro batch and handles its update operations.
+    It is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
+    details, please refer to the docs of those two classes below.
+
+    Args:
+        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
+        output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
+    """
 
     def __init__(
         self,
-        mb_inputs: Dict[str, torch.Tensor],
-        interval_inputs: Dict[str, torch.Tensor],
+        inputs_dict: Dict[str, torch.Tensor],
+        output_dict: Dict[str, torch.Tensor],
        new_length: int,
    ) -> None:
-        if mb_inputs is not None:
-            assert mb_inputs.get('input_ids') is not None and mb_inputs.get('attention_mask') is not None
-            self.mb_length = mb_inputs['input_ids'].shape[-1]
-            self.attn_mask = mb_inputs['attention_mask']
-            self.input_ids = mb_inputs['input_ids']
-
-        elif interval_inputs is not None:
-            assert interval_inputs.get('hidden_states') is not None
-            self.mb_length = interval_inputs['hidden_states'].shape[-2]
-        else:
-            raise ValueError('mb_inputs and interval_inputs can not be None at the same time')
-
+        assert output_dict.get('hidden_states') is not None
+        self.mb_length = output_dict['hidden_states'].shape[-2]
         self.target_length = self.mb_length + new_length
         self.kv_cache = ()
 
-    def update(self, kv_cache):
+    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
+        if output_dict is not None:
+            self._update_kvcache(output_dict['past_key_values'])
+
+    def _update_kvcache(self, kv_cache: Tuple):
+        assert type(kv_cache) == tuple
         self.kv_cache = kv_cache
 
+    @property
+    def state(self):
+        """
+        Return the state of the current micro batch: when the current length equals the target length,
+        the state is DONE, otherwise GENERATE.
+
+        """
+        # TODO: add the condition for early stopping
+        if self.cur_length == self.target_length:
+            return Status.DONE
+        else:
+            return Status.GENERATE
+
     @property
     def cur_length(self):
         """
-        Return the current sequnence length of micro batch, when there is no kv_cache, the length is mb_length,
-        otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1
+        Return the current sequence length of the micro batch.
 
         """
-        if len(self.kv_cache) == 0:
+        pass
+
+
+class HeadMicroBatchDescription(MicroBatchDescription):
+    """
+    This class records the information of the first stage of the pipeline. The first stage has the
+    attributes `input_ids`, `attention_mask` and `new_tokens`, where `new_tokens` holds the tokens
+    generated so far. Due to the pipeline schedule, the way this stage updates its information and
+    determines its state differs from the other stages.
+
+    Args:
+        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
+        output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
+        new_length (int): the new length of the input sequence.
+
+    """
+
+    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
+                 new_length: int) -> None:
+        super().__init__(inputs_dict, output_dict, new_length)
+        assert inputs_dict is not None
+        assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
+        self.input_ids = inputs_dict['input_ids']
+        self.attn_mask = inputs_dict['attention_mask']
+        self.new_tokens = None
+
+    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
+        super().update(output_dict, new_token)
+        if new_token is not None:
+            self._update_newtokens(new_token)
+        if self.state is not Status.DONE and new_token is not None:
+            self._update_attnmask()
+
+    def _update_newtokens(self, new_token: torch.Tensor):
+        if self.new_tokens is None:
+            self.new_tokens = new_token
+        else:
+            self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)
+
+    def _update_attnmask(self):
+        self.attn_mask = torch.cat(
+            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
+
+    @property
+    def cur_length(self):
+        """
+        When there are no new tokens, the length is `mb_length`; otherwise it is `mb_length` plus the number of new tokens.
+
+        """
+        if self.new_tokens is None:
             return self.mb_length
-        return self.kv_cache[0][0].shape[-2] + 1
+        else:
+            return self.mb_length + len(self.new_tokens[0])
+
+
+class BodyMicroBatchDescription(MicroBatchDescription):
+    """
+    This class records the information of every stage except the first. These stages have the
+    attributes `hidden_states` and `past_key_values`.
+
+    Args:
+        inputs_dict (Dict[str, torch.Tensor]): will always be `None`, since non-head stages only receive hidden states from the previous stage.
+        output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
+    """
+
+    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
+                 new_length: int) -> None:
+        super().__init__(inputs_dict, output_dict, new_length)
+
+    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
+        super().update(output_dict, new_token)
 
     @property
-    def state(self):
+    def cur_length(self):
         """
-        Return the state of current micro batch, when current length is equal to target length,
-        the state is DONE, otherwise GENERATE
+        When there is no kv_cache, the length is `mb_length`; otherwise it is `kv_cache[0][0].shape[-2]` plus 1.
 
         """
-        if self.cur_length == self.target_length:
-            return Status.DONE
+        if len(self.kv_cache) == 0:
+            return self.mb_length
         else:
-            return Status.GENERATE
+            return self.kv_cache[0][0].shape[-2] + 1
 
 
 class MicroBatchManager():
@@ -67,12 +148,15 @@ class MicroBatchManager():
     MicroBatchManager is a class that manages the micro batch.
 
     Args:
+        stage (int): the stage id of the current stage.
         new_length (int): the new length of the input sequence.
         micro_batch_size (int): the micro batch size.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+
     '''
 
-    def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
+    def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
+        self.stage = stage
         self.new_length = new_length
         self.micro_batch_size = micro_batch_size
         self.buffer_size = micro_batch_buffer_size
@@ -80,33 +164,28 @@ def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_si
         self.new_tokens_buffer = {}
         self.idx = 0
 
-    def _add_descrption(self, mb_inputs: Dict[str, torch.Tensor], inter_inputs: Dict[str, torch.Tensor]):
-        self.mb_descrption_buffer[self.idx] = MicroBatchDescription(mb_inputs, inter_inputs, self.new_length)
-
-    def _update_descrption(self, present_kv):
-        self.mb_descrption_buffer[self.idx].update(present_kv)
-
-    def _remove_descrption(self):
-        self.mb_descrption_buffer.pop(self.idx)
-
-    def step(self, mb_inputs=None, inter_inputs=None, present_kv=None):
+    def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
         """
-        Update the state if microbatch manager
+        Update the state of the micro batch manager. There are two cases:
+        1. The first stage in PREFILL receives both inputs and outputs; `_add_descrption` saves the inputs.
+        2. In every other case, only the output of the previous stage is received, and the description is updated.
 
         Args:
-            mb_inputs (int, optional): The input of first stage when in prefill, should be a dict like {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}.
-            inter_inputs ([type], optional): The input of intermediate stage (the output of previous stage), should be a dict like {'hidden_state': torch.Tensor}.
-            present_kv ([type], optional): The kvcache of current microbatch in current stage.
+            inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
+            output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
+            new_token (torch.Tensor): the new token generated by the current stage.
 
         """
+        # If nothing is received, the current stage is still in its PREFILL phase
+        if inputs_dict is None and output_dict is None and new_token is None:
+            return Status.PREFILL
         if self.mb_descrption_buffer.get(self.idx) is None:
-            self._add_descrption(mb_inputs, inter_inputs)
-        self._update_descrption(present_kv)
-        state = self.cur_state
-        self.next()
-        return state
+            self._add_descrption(inputs_dict, output_dict)
+        self.cur_descrption.update(output_dict, new_token)
+        return self.cur_state
 
-    def next(self):
-        self.idx = (self.idx + 1) % self.buffer_size
+    def export_new_tokens(self):
+        new_tokens_list = [i.new_tokens[0].tolist() for i in self.mb_descrption_buffer.values()]
+        return new_tokens_list
 
     def is_micro_batch_done(self):
         if len(self.mb_descrption_buffer) == 0:
@@ -114,20 +193,22 @@ def is_micro_batch_done(self):
         for mb in self.mb_descrption_buffer.values():
             if mb.state != Status.DONE:
                 return False
-        self.mb_descrption_buffer.clear()
         return True
 
-    def add_new_tokens(self, new_token):
-        if self.idx not in self.new_tokens_buffer:
-            self.new_tokens_buffer[self.idx] = new_token
+    def clear(self):
+        self.mb_descrption_buffer.clear()
+
+    def next(self):
+        self.idx = (self.idx + 1) % self.buffer_size
+
+    def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
+        if self.stage == 0:
+            self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length)
         else:
-            self.new_tokens_buffer[self.idx] = torch.cat([self.new_tokens_buffer[self.idx], new_token], dim=-1)
+            self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length)
 
-    def export_new_tokens(self):
-        list = [item.tolist() for item in self.new_tokens_buffer.values()]
-        flat_list = [item for sublist in list for item in sublist]
-        self.new_tokens_buffer.clear()
-        return flat_list
+    def _remove_descrption(self):
+        self.mb_descrption_buffer.pop(self.idx)
 
     @property
     def cur_descrption(self) -> MicroBatchDescription:
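The head/body split is easiest to see through the length bookkeeping. Below is a small CPU-runnable walkthrough of `BodyMicroBatchDescription` with assumed toy shapes; the head-side counterpart counts via `new_tokens` instead, and note that its `_update_attnmask` hard-codes `device='cuda'`, so the equivalent head walkthrough needs a GPU:

```python
import torch
from colossalai.inference.pipeline.microbatch_manager import BodyMicroBatchDescription, Status

# Shapes are illustrative only: batch 1, prompt length 4, hidden size 8, new_length 2.
desc = BodyMicroBatchDescription(
    inputs_dict=None,                           # always None off the head stage
    output_dict={'hidden_states': torch.zeros(1, 4, 8),
                 'past_key_values': ()},
    new_length=2,
)
assert desc.cur_length == 4 and desc.state is Status.GENERATE   # target is 4 + 2 = 6

# Each decode step hands back a grown kv cache; (batch, heads, seq, head_dim) here.
kv = ((torch.zeros(1, 1, 5, 8), torch.zeros(1, 1, 5, 8)),)
desc.update(output_dict={'hidden_states': torch.zeros(1, 1, 8), 'past_key_values': kv})
assert desc.cur_length == 6 and desc.state is Status.DONE       # kv seq 5 -> length 5 + 1
```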
+ """ + + def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], + new_length: int) -> None: + super().__init__(inputs_dict, output_dict, new_length) + + def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + super().update(output_dict, new_token) @property - def state(self): + def cur_length(self): """ - Return the state of current micro batch, when current length is equal to target length, - the state is DONE, otherwise GENERATE + When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 """ - if self.cur_length == self.target_length: - return Status.DONE + if len(self.kv_cache) == 0: + return self.mb_length else: - return Status.GENERATE + return self.kv_cache[0][0].shape[-2] + 1 class MicroBatchManager(): @@ -67,12 +148,15 @@ class MicroBatchManager(): MicroBatchManager is a class that manages the micro batch. Args: + stage (int): stage id of current stage. new_length (int): the new length of the input sequence. micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + ''' - def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + self.stage = stage self.new_length = new_length self.micro_batch_size = micro_batch_size self.buffer_size = micro_batch_buffer_size @@ -80,33 +164,28 @@ def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_si self.new_tokens_buffer = {} self.idx = 0 - def _add_descrption(self, mb_inputs: Dict[str, torch.Tensor], inter_inputs: Dict[str, torch.Tensor]): - self.mb_descrption_buffer[self.idx] = MicroBatchDescription(mb_inputs, inter_inputs, self.new_length) - - def _update_descrption(self, present_kv): - self.mb_descrption_buffer[self.idx].update(present_kv) - - def _remove_descrption(self): - self.mb_descrption_buffer.pop(self.idx) - - def step(self, mb_inputs=None, inter_inputs=None, present_kv=None): + def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): """ - Update the state if microbatch manager + Update the state if microbatch manager, 2 conditions. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. + 2. For other conditon, only receive the output of previous stage, and update the descrption. Args: - mb_inputs (int, optional): The input of first stage when in prefill, should be a dict like {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}. - inter_inputs ([type], optional): The input of intermediate stage (the output of previous stage), should be a dict like {'hidden_state': torch.Tensor}. - present_kv ([type], optional): The kvcache of current microbatch in current stage. + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + new_token (torch.Tensor): the new token generated by current stage. 
""" + # Add descrption first if the descrption is None + if inputs_dict is None and output_dict is None and new_token is None: + return Status.PREFILL if self.mb_descrption_buffer.get(self.idx) is None: - self._add_descrption(mb_inputs, inter_inputs) - self._update_descrption(present_kv) - state = self.cur_state - self.next() - return state + self._add_descrption(inputs_dict, output_dict) + self.cur_descrption.update(output_dict, new_token) + return self.cur_state - def next(self): - self.idx = (self.idx + 1) % self.buffer_size + def export_new_tokens(self): + new_tokens_list = [i.new_tokens[0].tolist() for i in self.mb_descrption_buffer.values()] + return new_tokens_list def is_micro_batch_done(self): if len(self.mb_descrption_buffer) == 0: @@ -114,20 +193,22 @@ def is_micro_batch_done(self): for mb in self.mb_descrption_buffer.values(): if mb.state != Status.DONE: return False - self.mb_descrption_buffer.clear() return True - def add_new_tokens(self, new_token): - if self.idx not in self.new_tokens_buffer: - self.new_tokens_buffer[self.idx] = new_token + def clear(self): + self.mb_descrption_buffer.clear() + + def next(self): + self.idx = (self.idx + 1) % self.buffer_size + + def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]): + if self.stage == 0: + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length) else: - self.new_tokens_buffer[self.idx] = torch.cat([self.new_tokens_buffer[self.idx], new_token], dim=-1) + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length) - def export_new_tokens(self): - list = [item.tolist() for item in self.new_tokens_buffer.values()] - flat_list = [item for sublist in list for item in sublist] - self.new_tokens_buffer.clear() - return flat_list + def _remove_descrption(self): + self.mb_descrption_buffer.pop(self.idx) @property def cur_descrption(self) -> MicroBatchDescription: diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py index 773fb2a07899..f490710c1f7f 100644 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ b/colossalai/inference/pipeline/modeling/gpt2.py @@ -251,6 +251,12 @@ def gpt2_lmhead_model_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {'logits': lm_logits} + + # Not first stage or before warmup, go through gpt2 model outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, input_ids, past_key_values=past_key_values, @@ -269,24 +275,4 @@ def gpt2_lmhead_model_forward( hidden_states=hidden_states, stage_index=stage_index) - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return outputs - - hidden_states = outputs['hidden_states'] - lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - if not return_dict: - output = 
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index e12616655d32..a6ca76a83468 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -16,6 +16,15 @@
 
 
 class GenerateSchedule(PipelineSchedule):
+    '''
+    GenerateSchedule is a class that handles pipeline-parallel inference.
+    In our schedule, we place the tied-weight layers (the embedding and lm_head) on the same device
+    to save memory, so the output of each decoding step is produced on rank 0.
+
+    Args:
+        stage_manager (PipelineStageManager): Pipeline stage manager.
+        mb_manager (MicroBatchManager): Micro batch manager.
+    '''
 
     def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager) -> None:
         super().__init__(stage_manager)
@@ -55,35 +64,33 @@ def load_micro_batch(self) -> Any:
         self.microbatch_offset += self.microbatch_size
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)
 
-    def _prepare_stage_inputs(self):
-        # first stage and in prefill phase
-        if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
-            pre_stage_out = None
-            model_inputs = self.load_micro_batch()
-            hidden_states = None
-        # first stage and in generate phase
-        elif self.stage_manager.is_first_stage():
-            pre_stage_out = self.comm.recv_forward()
-            model_inputs = self._prepare_next_token(pre_stage_out)
-            hidden_states = None
-        # not first stage and in gererate phase
-        else:
-            pre_stage_out = self.comm.recv_forward()
-            model_inputs = {
-                'past_key_values': self.mb_manager.cur_kv_cache
-            } if self.mb_manager.cur_kv_cache is not None else None
-            hidden_states = pre_stage_out
-        return pre_stage_out, model_inputs, hidden_states
-
-    def _prepare_next_token(self, inputs: Dict[str, torch.Tensor]):
+    def _prepare_inputs_for_interval_stage(self):
+        '''
+        Prepare inputs for the intermediate stages. For every intermediate stage, the input is just the past_key_values.
+
+        Returns:
+            dict: inputs for the intermediate stage, `{'past_key_values': torch.Tensor}` or `None`
+        '''
+        model_inputs = {
+            'past_key_values': self.mb_manager.cur_kv_cache
+        } if self.mb_manager.cur_kv_cache is not None else None
+        return model_inputs
+
+    def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
+        '''
+        Prepare inputs for the new token. The result is a dict with keys `input_ids`, `attention_mask` and
+        `past_key_values`: `input_ids` is the new token, `attention_mask` is the up-to-date mask kept by the
+        micro batch manager (grown by one each step), and `past_key_values` is the kv cache saved in the
+        micro batch manager.
+
+        Returns:
+            dict: inputs for the new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
+        '''
         new_mask = self.mb_manager.cur_descrption.attn_mask
-        new_mask = torch.cat((new_mask, torch.ones((new_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
-        self.mb_manager.cur_descrption.attn_mask = new_mask
         past_key_values = self.mb_manager.cur_descrption.kv_cache
-        return dict(input_ids=inputs['new_token'], attention_mask=new_mask, past_key_values=past_key_values)
+        return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)
 
-    def get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
+    def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
         last_hidden_state = hidden_state[:, -1]
         input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1)
         return input_ids
@@ -93,11 +100,8 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso
         """Forward one step of the pipeline
 
         Args:
-            model (Module): Model to be run
-            input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
-            criterion (Callable): Criterion to calculate loss.
-            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
-            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+            model (Module): Model to be run.
+            data_iter (Iterable): Data iterator.
 
         Returns:
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
@@ -108,20 +112,46 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso
 
         # run by round
         for _ in range(self.round):
-            state = Status.PREFILL
             while self.mb_manager.is_micro_batch_done() is False:
-                pre_stage_out, model_inputs, hidden_states = self._prepare_stage_inputs()
-
-                output_obj = model_forward(model, model_inputs, hidden_states)
-
-                past_key_values = output_obj.get('past_key_values', None)
-                state = self.mb_manager.step(model_inputs, pre_stage_out, past_key_values)
-                if self.stage_manager.is_last_stage():
-                    new_token = self.get_token_id(output_obj['hidden_states'])
-                    self.mb_manager.add_new_tokens(new_token)
-                    if state is not Status.DONE:
-                        self.comm.send_forward({'new_token': new_token})
+                inputs_dict = None
+                new_token = None
+                output_dict = None
+
+                # First stage in PREFILL phase: just load the inputs
+                if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
+                    inputs_dict = self.load_micro_batch()
+                    output_dict = model_forward(model, inputs_dict, None)
+                    self.mb_manager.step(inputs_dict, output_dict, None)
+                # In GENERATE phase
                 else:
-                    self.comm.send_forward({'hidden_states': output_obj['hidden_states']})
-            output_sequence.extend(self.mb_manager.export_new_tokens())
+                    # Get hidden_states from the previous stage
+                    hidden_states = self.comm.recv_forward()
+                    if self.stage_manager.is_first_stage():
+                        # First, generate a new token
+                        assert hidden_states is not None, "When the first stage is in GENERATE phase, the hidden states should not be None"
+                        logits = model_forward(model, None, hidden_states)
+                        assert 'logits' in logits, f"When the first stage is in GENERATE phase, the output should have the attribute `logits`, but has {logits.keys()}"
+                        new_token = self._get_token_id(logits['logits'])
+                        self.mb_manager.step(None, None, new_token)
+                        # If the current micro batch is not DONE, go through the transformer blocks
+                        if self.mb_manager.cur_state is Status.GENERATE:
+                            inputs_dict = self._prepare_inputs_for_new_token(new_token)
+                            output_dict = model_forward(model, inputs_dict, None)
+                            self.mb_manager.step(inputs_dict, output_dict, None)
+                    else:
+                        assert hidden_states is not None, "When not the first stage, the hidden states should not be None"
+                        inputs_dict = self._prepare_inputs_for_interval_stage()
+                        output_dict = model_forward(model, inputs_dict, hidden_states)
+                        self.mb_manager.step(inputs_dict, output_dict, None)
+
+                # If the current micro batch is not DONE, send hidden_states to the next stage
+                if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state is Status.GENERATE:
+                    self.comm.send_forward({'hidden_states': output_dict['hidden_states']})
+
+                self.mb_manager.next()
+
+            # All micro batches in the current round are DONE
+            if self.stage_manager.is_first_stage():
+                output_sequence.extend(self.mb_manager.export_new_tokens())
+            self.mb_manager.clear()
 
         return output_sequence
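Decoding itself stays greedy: `_get_token_id` (note the parameter is still named `hidden_state`, though it now receives logits from the first stage) takes the last position's logits and argmaxes over the vocabulary, keeping a `(batch, 1)` shape so the token can be fed straight back in as `input_ids`. A standalone sketch with assumed sizes:

```python
import torch

logits = torch.randn(2, 5, 50257)   # (batch, seq, vocab); sizes are illustrative
new_token = torch.argmax(logits[:, -1], dim=-1).unsqueeze(1)   # same math as _get_token_id
assert new_token.shape == (2, 1)    # ready to be used as input_ids for the next step
```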
diff --git a/tests/test_generate/test_pipeline_infer.py b/tests/test_generate/test_pipeline_infer.py
index 5bc2f1857536..47cf9e78d138 100644
--- a/tests/test_generate/test_pipeline_infer.py
+++ b/tests/test_generate/test_pipeline_infer.py
@@ -34,7 +34,7 @@ def pipeline_inference_test(pp_size, new_length, micro_batch_size):
                              new_length=new_length,
                              micro_batch_size=micro_batch_size)
     output = engine.inference([inputs])
-    if dist.get_rank() == dist.get_world_size() - 1:
+    if dist.get_rank() == 0:
         assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
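The test change encodes the new output contract: generated tokens are exported on rank 0 (the first stage), not the last rank. A hypothetical driver snippet, assuming an `engine` built exactly as in the test above; everything except the tokenizer calls depends on that assumed setup:

```python
import torch.distributed as dist
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
inputs = tokenizer(["Hello, my dog is cute"], return_tensors='pt')

output = engine.inference([inputs])     # runs the generate schedule on every rank
if dist.get_rank() == 0:                # only the head stage holds the new tokens now
    print(tokenizer.batch_decode(output))
```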