From fdeb661b807a85a85e6e87d92a775817f4126b1c Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 30 Aug 2023 17:18:52 +0800 Subject: [PATCH 1/7] [pipeline inference] pipeline inference (#4492) * add pp stage manager as circle stage * fix a bug when create process group * add ppinfer basic framework * add micro batch manager and support kvcache-pp gpt2 fwd * add generate schedule * use mb size to control mb number * support generate with kv cache * add output, remove unused code * add test * reuse shardformer to build model * refactor some code and use the same attribute name of hf * fix review and add test for generation * remove unused file * fix CI * add cache clear * fix code error * fix typo --- colossalai/inference/__init__.py | 3 + colossalai/inference/pipeline/__init__.py | 3 + colossalai/inference/pipeline/engine.py | 93 ++++++ .../inference/pipeline/microbatch_manager.py | 150 +++++++++ .../inference/pipeline/modeling/__init__.py | 0 .../inference/pipeline/modeling/gpt2.py | 292 ++++++++++++++++++ .../inference/pipeline/policy/gpt2_ppinfer.py | 69 +++++ colossalai/inference/pipeline/utils.py | 35 +++ colossalai/pipeline/schedule/generate.py | 127 ++++++++ colossalai/pipeline/stage_manager.py | 2 + .../test_low_level_zero_checkpoint_io.py | 1 + tests/test_generate/test_pipeline_infer.py | 63 ++++ 12 files changed, 838 insertions(+) create mode 100644 colossalai/inference/pipeline/__init__.py create mode 100644 colossalai/inference/pipeline/engine.py create mode 100644 colossalai/inference/pipeline/microbatch_manager.py create mode 100644 colossalai/inference/pipeline/modeling/__init__.py create mode 100644 colossalai/inference/pipeline/modeling/gpt2.py create mode 100644 colossalai/inference/pipeline/policy/gpt2_ppinfer.py create mode 100644 colossalai/inference/pipeline/utils.py create mode 100644 colossalai/pipeline/schedule/generate.py create mode 100644 tests/test_generate/test_pipeline_infer.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index e69de29bb2d1..db33ae6fe998 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -0,0 +1,3 @@ +from .pipeline import PPInferEngine + +__all__ = ['PPInferEngine'] diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py new file mode 100644 index 000000000000..aff4568f7d08 --- /dev/null +++ b/colossalai/inference/pipeline/__init__.py @@ -0,0 +1,3 @@ +from .engine import PPInferEngine + +__all__ = ['PPInferEngine'] diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py new file mode 100644 index 000000000000..0c748d725d5d --- /dev/null +++ b/colossalai/inference/pipeline/engine.py @@ -0,0 +1,93 @@ +import re +from functools import partial +from types import MethodType +from typing import Callable, List, Optional, Set + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.schedule.generate import GenerateSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer._utils import getattr_ +from colossalai.shardformer.policies.base_policy import Policy + +from .microbatch_manager import MicroBatchManager +from .policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy +from .utils import get_suffix_name, set_tensors_to_none + + +class PPInferEngine: + ''' + PPInferEngine is a class that handles the pipeline parallel inference. + + Args: + pp_size (int): the number of pipeline stages. + pp_model (`nn.Module`): the model already in pipeline parallelism style. + model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. + model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. + micro_batch_size (int): the micro batch size. + micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + new_length (int): the new length of the input sequence. + early_stopping (bool): whether to stop early. + + Example: + + ```python + from colossalai.ppinference import PPInferEngine + from transformers import GPT2LMHeadModel, GPT2Tokenizer + + model = transformers.GPT2LMHeadModel.from_pretrained('gpt2') + # assume the model is infered with 4 pipeline stages + inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding}) + + input = ["Hello, my dog is cute, and I like"] + tokenized_input = tokenizer(input, return_tensors='pt') + output = engine.inference([tokenized_input]) + ``` + + ''' + + def __init__( + self, + pp_size: int, + pp_model: nn.Module = None, + model: nn.Module = None, + model_policy: Policy = None, + new_length: int = 32, + micro_batch_size: int = 1, + micro_batch_buffer_size: int = None, + # TODO: implement early_stopping, and various gerneration options + early_stopping: bool = False, + do_sample: bool = False, + num_beams: int = 1, + ) -> None: + assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." + self.pp_size = pp_size + self.pg_mesh = ProcessGroupMesh(pp_size) + self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) + self.mb_manager = MicroBatchManager(new_length, micro_batch_size, micro_batch_buffer_size or pp_size) + self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager) + self.model = pp_model or self._shardformer(model, model_policy) + + def inference(self, input_list): + out = self.schedule.generate_step(self.model, iter(input_list)) + return out + + def _shardformer(self, model, model_policy): + shardconfig = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model.cuda() diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py new file mode 100644 index 000000000000..f54396bb3747 --- /dev/null +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -0,0 +1,150 @@ +from enum import Enum +from typing import Dict + +import torch + +__all__ = 'MicroBatchManager' + + +class Status(Enum): + PREFILL = 1 + GENERATE = 2 + DONE = 3 + + +class MicroBatchDescription(): + + def __init__( + self, + mb_inputs: Dict[str, torch.Tensor], + interval_inputs: Dict[str, torch.Tensor], + new_length: int, + ) -> None: + if mb_inputs is not None: + assert mb_inputs.get('input_ids') is not None and mb_inputs.get('attention_mask') is not None + self.mb_length = mb_inputs['input_ids'].shape[-1] + self.attn_mask = mb_inputs['attention_mask'] + self.input_ids = mb_inputs['input_ids'] + + elif interval_inputs is not None: + assert interval_inputs.get('hidden_states') is not None + self.mb_length = interval_inputs['hidden_states'].shape[-2] + else: + raise ValueError('mb_inputs and interval_inputs can not be None at the same time') + + self.target_length = self.mb_length + new_length + self.kv_cache = () + + def update(self, kv_cache): + self.kv_cache = kv_cache + + @property + def cur_length(self): + """ + Return the current sequnence length of micro batch, when there is no kv_cache, the length is mb_length, + otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 + + """ + if len(self.kv_cache) == 0: + return self.mb_length + return self.kv_cache[0][0].shape[-2] + 1 + + @property + def state(self): + """ + Return the state of current micro batch, when current length is equal to target length, + the state is DONE, otherwise GENERATE + + """ + if self.cur_length == self.target_length: + return Status.DONE + else: + return Status.GENERATE + + +class MicroBatchManager(): + ''' + MicroBatchManager is a class that manages the micro batch. + + Args: + new_length (int): the new length of the input sequence. + micro_batch_size (int): the micro batch size. + micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + ''' + + def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + self.new_length = new_length + self.micro_batch_size = micro_batch_size + self.buffer_size = micro_batch_buffer_size + self.mb_descrption_buffer = {} + self.new_tokens_buffer = {} + self.idx = 0 + + def _add_descrption(self, mb_inputs: Dict[str, torch.Tensor], inter_inputs: Dict[str, torch.Tensor]): + self.mb_descrption_buffer[self.idx] = MicroBatchDescription(mb_inputs, inter_inputs, self.new_length) + + def _update_descrption(self, present_kv): + self.mb_descrption_buffer[self.idx].update(present_kv) + + def _remove_descrption(self): + self.mb_descrption_buffer.pop(self.idx) + + def step(self, mb_inputs=None, inter_inputs=None, present_kv=None): + """ + Update the state if microbatch manager + + Args: + mb_inputs (int, optional): The input of first stage when in prefill, should be a dict like {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}. + inter_inputs ([type], optional): The input of intermediate stage (the output of previous stage), should be a dict like {'hidden_state': torch.Tensor}. + present_kv ([type], optional): The kvcache of current microbatch in current stage. + """ + if self.mb_descrption_buffer.get(self.idx) is None: + self._add_descrption(mb_inputs, inter_inputs) + self._update_descrption(present_kv) + state = self.cur_state + self.next() + return state + + def next(self): + self.idx = (self.idx + 1) % self.buffer_size + + def is_micro_batch_done(self): + if len(self.mb_descrption_buffer) == 0: + return False + for mb in self.mb_descrption_buffer.values(): + if mb.state != Status.DONE: + return False + self.mb_descrption_buffer.clear() + return True + + def add_new_tokens(self, new_token): + if self.idx not in self.new_tokens_buffer: + self.new_tokens_buffer[self.idx] = new_token + else: + self.new_tokens_buffer[self.idx] = torch.cat([self.new_tokens_buffer[self.idx], new_token], dim=-1) + + def export_new_tokens(self): + list = [item.tolist() for item in self.new_tokens_buffer.values()] + flat_list = [item for sublist in list for item in sublist] + self.new_tokens_buffer.clear() + return flat_list + + @property + def cur_descrption(self) -> MicroBatchDescription: + return self.mb_descrption_buffer.get(self.idx) + + @property + def cur_kv_cache(self): + if self.cur_descrption is None: + return None + return self.cur_descrption.kv_cache + + @property + def cur_state(self): + """ + Return the state of current micro batch, when current descrption is None, the state is PREFILL + + """ + if self.cur_descrption is None: + return Status.PREFILL + return self.cur_descrption.state diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/pipeline/modeling/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py new file mode 100644 index 000000000000..773fb2a07899 --- /dev/null +++ b/colossalai/inference/pipeline/modeling/gpt2.py @@ -0,0 +1,292 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class GPT2PipelineForwards: + ''' + This class serves as a micro library for forward function substitution of GPT2 models + under pipeline setting. + ''' + + @staticmethod + def gpt2_model_forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. + # Please refer to original code of transformers for more details. + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + else: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i, layer_past in zip(range(start_idx, end_idx), past_key_values): + block = self.h[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return {'hidden_states': hidden_states, 'past_key_values': presents} + + @staticmethod + def gpt2_lmhead_model_forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return outputs + + hidden_states = outputs['hidden_states'] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return {'hidden_states': lm_logits, 'past_key_values': outputs['past_key_values']} diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py new file mode 100644 index 000000000000..3e4ad30f96ed --- /dev/null +++ b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py @@ -0,0 +1,69 @@ +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.policies.gpt2 import GPT2Policy + +from ..modeling.gpt2 import GPT2PipelineForwards + + +class GPT2LMHeadModelPipelinePolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + '''The weights of wte and lm_head are shared.''' + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'GPT2Model': + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py new file mode 100644 index 000000000000..1a6e8a519397 --- /dev/null +++ b/colossalai/inference/pipeline/utils.py @@ -0,0 +1,35 @@ +from typing import List, Optional, Set + +import torch.nn as nn + +from colossalai.shardformer._utils import getattr_, setattr_ + + +def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None: + """ + Set all parameters and buffers of model to None + + Args: + model (nn.Module): The model to set + """ + for module_suffix in include: + set_module = getattr_(model, module_suffix) + for n, p in set_module.named_parameters(): + setattr_(set_module, n, None) + for n, buf in set_module.named_buffers(): + setattr_(set_module, n, None) + setattr_(model, module_suffix, None) + + +def get_suffix_name(suffix: str, name: str): + """ + Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit, + and 'name' when `suffix` is empty. + + Args: + suffix (str): The suffix of the suffix module + name (str): The name of the current module + """ + point = '' if suffix is '' else '.' + suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}' + return suffix_name diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py new file mode 100644 index 000000000000..e12616655d32 --- /dev/null +++ b/colossalai/pipeline/schedule/generate.py @@ -0,0 +1,127 @@ +from functools import partial +from typing import Any, Dict, Iterable, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import get_batch_size, get_micro_batch, model_forward, to_device +from .base import PipelineSchedule + + +class GenerateSchedule(PipelineSchedule): + + def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager) -> None: + super().__init__(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager) + self.mb_manager = mb_manager + self.microbatch_size = mb_manager.micro_batch_size + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self.num_microbatches: Optional[int] = None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = 0 + assert self.batch_size % self.microbatch_size == 0, \ + f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" + self.num_microbatches = self.batch_size // self.microbatch_size + self.round = self.num_microbatches // self.stage_manager.num_stages + + def load_micro_batch(self) -> Any: + """Load a micro batch from the current batch. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) + self.microbatch_offset += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def _prepare_stage_inputs(self): + # first stage and in prefill phase + if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: + pre_stage_out = None + model_inputs = self.load_micro_batch() + hidden_states = None + # first stage and in generate phase + elif self.stage_manager.is_first_stage(): + pre_stage_out = self.comm.recv_forward() + model_inputs = self._prepare_next_token(pre_stage_out) + hidden_states = None + # not first stage and in gererate phase + else: + pre_stage_out = self.comm.recv_forward() + model_inputs = { + 'past_key_values': self.mb_manager.cur_kv_cache + } if self.mb_manager.cur_kv_cache is not None else None + hidden_states = pre_stage_out + return pre_stage_out, model_inputs, hidden_states + + def _prepare_next_token(self, inputs: Dict[str, torch.Tensor]): + new_mask = self.mb_manager.cur_descrption.attn_mask + new_mask = torch.cat((new_mask, torch.ones((new_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) + self.mb_manager.cur_descrption.attn_mask = new_mask + past_key_values = self.mb_manager.cur_descrption.kv_cache + + return dict(input_ids=inputs['new_token'], attention_mask=new_mask, past_key_values=past_key_values) + + def get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: + last_hidden_state = hidden_state[:, -1] + input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1) + return input_ids + + @torch.no_grad() + def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + + Args: + model (Module): Model to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + output_sequence = [] + self.load_batch(data_iter) + model.eval() + + # run by round + for _ in range(self.round): + state = Status.PREFILL + while self.mb_manager.is_micro_batch_done() is False: + pre_stage_out, model_inputs, hidden_states = self._prepare_stage_inputs() + + output_obj = model_forward(model, model_inputs, hidden_states) + + past_key_values = output_obj.get('past_key_values', None) + state = self.mb_manager.step(model_inputs, pre_stage_out, past_key_values) + if self.stage_manager.is_last_stage(): + new_token = self.get_token_id(output_obj['hidden_states']) + self.mb_manager.add_new_tokens(new_token) + if state is not Status.DONE: + self.comm.send_forward({'new_token': new_token}) + else: + self.comm.send_forward({'hidden_states': output_obj['hidden_states']}) + output_sequence.extend(self.mb_manager.export_new_tokens()) + return output_sequence diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index b79867a2c651..d988015ceeda 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -12,6 +12,7 @@ class PipelineStageManager: Args: pg_mesh (ProcessGroupMesh): Process group mesh. pipeline_axis (int): The axis along which the pipeline is constructed. + is_virtual (bool): Whether to use circle p2p communication, it will make the first and last stage communicate with each other. Attributes: num_stages (int): Number of stages in the pipeline. @@ -24,6 +25,7 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + # init prev and next coord coord = self.pg_mesh.coordinate() # the prev rank of rank0 is the last rank diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 8a4724c8a82c..e7f44f97e3cf 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -66,6 +66,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) + torch.cuda.empty_cache() def run_dist(rank, world_size, port): diff --git a/tests/test_generate/test_pipeline_infer.py b/tests/test_generate/test_pipeline_infer.py new file mode 100644 index 000000000000..5bc2f1857536 --- /dev/null +++ b/tests/test_generate/test_pipeline_infer.py @@ -0,0 +1,63 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers + +import colossalai +from colossalai.inference.pipeline.engine import PPInferEngine +from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + + +def data_gen(): + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +inputs = data_gen() +for k, v in inputs.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 16 + inputs[k] = v.to('cuda').repeat(*new_shape) + + +def pipeline_inference_test(pp_size, new_length, micro_batch_size): + model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) + engine = PPInferEngine(pp_size=pp_size, + model=model, + model_policy=GPT2LMHeadModelPipelinePolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size) + output = engine.inference([inputs]) + if dist.get_rank() == dist.get_world_size() - 1: + assert len(output[0]) == new_length, f"{len(output)}, {new_length}" + + +@parameterize('pp_size', [4]) +@parameterize('new_length', [4, 8, 16]) +@parameterize('micro_batch_size', [1, 4]) +@clear_cache_before_run() +def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): + pipeline_inference_test(pp_size, new_length, micro_batch_size) + torch.cuda.empty_cache() + + +def check_pipeline_inference(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_pipeline_inference_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_pipeline_inference(): + spawn(check_pipeline_inference, nprocs=4) + + +if __name__ == '__main__': + test_pipeline_inference() From 99237c8358c0b739339055f46349e739be45ccb7 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 7 Sep 2023 11:11:44 +0800 Subject: [PATCH 2/7] [Pipeline inference] Modify to tieweight (#4599) * add pp stage manager as circle stage * fix a bug when create process group * add ppinfer basic framework * add micro batch manager and support kvcache-pp gpt2 fwd * add generate schedule * use mb size to control mb number * support generate with kv cache * add output, remove unused code * add test * reuse shardformer to build model * refactor some code and use the same attribute name of hf * fix review and add test for generation * remove unused file * modify the way of saving newtokens * modify to tieweight * modify test * remove unused file * solve review * add docstring --- colossalai/inference/pipeline/engine.py | 3 +- .../inference/pipeline/microbatch_manager.py | 197 ++++++++++++------ .../inference/pipeline/modeling/gpt2.py | 28 +-- .../inference/pipeline/policy/gpt2_ppinfer.py | 4 +- colossalai/pipeline/schedule/generate.py | 118 +++++++---- tests/test_generate/test_pipeline_infer.py | 2 +- 6 files changed, 226 insertions(+), 126 deletions(-) diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 0c748d725d5d..9236ee0a7bff 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -69,7 +69,8 @@ def __init__( self.pp_size = pp_size self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) - self.mb_manager = MicroBatchManager(new_length, micro_batch_size, micro_batch_buffer_size or pp_size) + self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, + micro_batch_buffer_size or pp_size) self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager) self.model = pp_model or self._shardformer(model, model_policy) diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index f54396bb3747..7f4b14c17748 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict +from typing import Dict, Tuple import torch @@ -13,53 +13,134 @@ class Status(Enum): class MicroBatchDescription(): + """ + This is the class to record the infomation of each microbatch, and also do some update operation. + This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more + details, please refer to the doc of these two classes blow. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + """ def __init__( self, - mb_inputs: Dict[str, torch.Tensor], - interval_inputs: Dict[str, torch.Tensor], + inputs_dict: Dict[str, torch.Tensor], + output_dict: Dict[str, torch.Tensor], new_length: int, ) -> None: - if mb_inputs is not None: - assert mb_inputs.get('input_ids') is not None and mb_inputs.get('attention_mask') is not None - self.mb_length = mb_inputs['input_ids'].shape[-1] - self.attn_mask = mb_inputs['attention_mask'] - self.input_ids = mb_inputs['input_ids'] - - elif interval_inputs is not None: - assert interval_inputs.get('hidden_states') is not None - self.mb_length = interval_inputs['hidden_states'].shape[-2] - else: - raise ValueError('mb_inputs and interval_inputs can not be None at the same time') - + assert output_dict.get('hidden_states') is not None + self.mb_length = output_dict['hidden_states'].shape[-2] self.target_length = self.mb_length + new_length self.kv_cache = () - def update(self, kv_cache): + def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + if output_dict is not None: + self._update_kvcache(output_dict['past_key_values']) + + def _update_kvcache(self, kv_cache: Tuple): + assert type(kv_cache) == tuple self.kv_cache = kv_cache + @property + def state(self): + """ + Return the state of current micro batch, when current length is equal to target length, + the state is DONE, otherwise GENERATE + + """ + # TODO: add the condition for early stopping + if self.cur_length == self.target_length: + return Status.DONE + else: + return Status.GENERATE + @property def cur_length(self): """ - Return the current sequnence length of micro batch, when there is no kv_cache, the length is mb_length, - otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 + Return the current sequnence length of micro batch """ - if len(self.kv_cache) == 0: + pass + + +class HeadMicroBatchDescription(MicroBatchDescription): + """ + This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` + and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the + information and the condition to determine the state is different from other stages. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + new_length (int): the new length of the input sequence. + + """ + + def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], + new_length: int) -> None: + super().__init__(inputs_dict, output_dict, new_length) + assert inputs_dict is not None + assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None + self.input_ids = inputs_dict['input_ids'] + self.attn_mask = inputs_dict['attention_mask'] + self.new_tokens = None + + def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + super().update(output_dict, new_token) + if new_token is not None: + self._update_newtokens(new_token) + if self.state is not Status.DONE and new_token is not None: + self._update_attnmask() + + def _update_newtokens(self, new_token: torch.Tensor): + if self.new_tokens is None: + self.new_tokens = new_token + else: + self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1) + + def _update_attnmask(self): + self.attn_mask = torch.cat( + (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) + + @property + def cur_length(self): + """ + When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token + + """ + if self.new_tokens is None: return self.mb_length - return self.kv_cache[0][0].shape[-2] + 1 + else: + return self.mb_length + len(self.new_tokens[0]) + + +class BodyMicroBatchDescription(MicroBatchDescription): + """ + This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, + + Args: + inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + """ + + def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], + new_length: int) -> None: + super().__init__(inputs_dict, output_dict, new_length) + + def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + super().update(output_dict, new_token) @property - def state(self): + def cur_length(self): """ - Return the state of current micro batch, when current length is equal to target length, - the state is DONE, otherwise GENERATE + When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 """ - if self.cur_length == self.target_length: - return Status.DONE + if len(self.kv_cache) == 0: + return self.mb_length else: - return Status.GENERATE + return self.kv_cache[0][0].shape[-2] + 1 class MicroBatchManager(): @@ -67,12 +148,15 @@ class MicroBatchManager(): MicroBatchManager is a class that manages the micro batch. Args: + stage (int): stage id of current stage. new_length (int): the new length of the input sequence. micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + ''' - def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + self.stage = stage self.new_length = new_length self.micro_batch_size = micro_batch_size self.buffer_size = micro_batch_buffer_size @@ -80,33 +164,28 @@ def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_si self.new_tokens_buffer = {} self.idx = 0 - def _add_descrption(self, mb_inputs: Dict[str, torch.Tensor], inter_inputs: Dict[str, torch.Tensor]): - self.mb_descrption_buffer[self.idx] = MicroBatchDescription(mb_inputs, inter_inputs, self.new_length) - - def _update_descrption(self, present_kv): - self.mb_descrption_buffer[self.idx].update(present_kv) - - def _remove_descrption(self): - self.mb_descrption_buffer.pop(self.idx) - - def step(self, mb_inputs=None, inter_inputs=None, present_kv=None): + def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): """ - Update the state if microbatch manager + Update the state if microbatch manager, 2 conditions. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. + 2. For other conditon, only receive the output of previous stage, and update the descrption. Args: - mb_inputs (int, optional): The input of first stage when in prefill, should be a dict like {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}. - inter_inputs ([type], optional): The input of intermediate stage (the output of previous stage), should be a dict like {'hidden_state': torch.Tensor}. - present_kv ([type], optional): The kvcache of current microbatch in current stage. + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + new_token (torch.Tensor): the new token generated by current stage. """ + # Add descrption first if the descrption is None + if inputs_dict is None and output_dict is None and new_token is None: + return Status.PREFILL if self.mb_descrption_buffer.get(self.idx) is None: - self._add_descrption(mb_inputs, inter_inputs) - self._update_descrption(present_kv) - state = self.cur_state - self.next() - return state + self._add_descrption(inputs_dict, output_dict) + self.cur_descrption.update(output_dict, new_token) + return self.cur_state - def next(self): - self.idx = (self.idx + 1) % self.buffer_size + def export_new_tokens(self): + new_tokens_list = [i.new_tokens[0].tolist() for i in self.mb_descrption_buffer.values()] + return new_tokens_list def is_micro_batch_done(self): if len(self.mb_descrption_buffer) == 0: @@ -114,20 +193,22 @@ def is_micro_batch_done(self): for mb in self.mb_descrption_buffer.values(): if mb.state != Status.DONE: return False - self.mb_descrption_buffer.clear() return True - def add_new_tokens(self, new_token): - if self.idx not in self.new_tokens_buffer: - self.new_tokens_buffer[self.idx] = new_token + def clear(self): + self.mb_descrption_buffer.clear() + + def next(self): + self.idx = (self.idx + 1) % self.buffer_size + + def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]): + if self.stage == 0: + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length) else: - self.new_tokens_buffer[self.idx] = torch.cat([self.new_tokens_buffer[self.idx], new_token], dim=-1) + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length) - def export_new_tokens(self): - list = [item.tolist() for item in self.new_tokens_buffer.values()] - flat_list = [item for sublist in list for item in sublist] - self.new_tokens_buffer.clear() - return flat_list + def _remove_descrption(self): + self.mb_descrption_buffer.pop(self.idx) @property def cur_descrption(self) -> MicroBatchDescription: diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py index 773fb2a07899..f490710c1f7f 100644 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ b/colossalai/inference/pipeline/modeling/gpt2.py @@ -251,6 +251,12 @@ def gpt2_lmhead_model_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {'logits': lm_logits} + + # Not first stage or before warmup, go through gpt2 model outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, input_ids, past_key_values=past_key_values, @@ -269,24 +275,4 @@ def gpt2_lmhead_model_forward( hidden_states=hidden_states, stage_index=stage_index) - # If not at the last stage, return hidden_states as in GPT2Model - if not stage_manager.is_last_stage(): - return outputs - - hidden_states = outputs['hidden_states'] - lm_logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return {'hidden_states': lm_logits, 'past_key_values': outputs['past_key_values']} + return outputs diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py index 3e4ad30f96ed..e51090200f83 100644 --- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py +++ b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py @@ -38,7 +38,9 @@ def module_policy(self): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): + # make the tie weight lm_head and embedding in the same device to save memory + # if self.pipeline_stage_manager.is_first_stage(): + if self.pipeline_stage_manager.is_first_stage(): held_layers.append(self.model.lm_head) return held_layers diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index e12616655d32..a6ca76a83468 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -16,6 +16,15 @@ class GenerateSchedule(PipelineSchedule): + ''' + GenerateSchedule is a class that handles the pipeline parallel inference. + In our schedule, we place tie weight layer, embedding and lm_head in the same device to save space, so in + this schedule, the out for each encoding progress is on rank0. + + Args: + stage_manager (PipelineStageManager): Pipeline stage manager. + mb_manager (MicroBatchManager): Micro batch manager. + ''' def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager) -> None: super().__init__(stage_manager) @@ -55,35 +64,33 @@ def load_micro_batch(self) -> Any: self.microbatch_offset += self.microbatch_size return tree_map(partial(to_device, device=get_current_device()), micro_batch) - def _prepare_stage_inputs(self): - # first stage and in prefill phase - if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: - pre_stage_out = None - model_inputs = self.load_micro_batch() - hidden_states = None - # first stage and in generate phase - elif self.stage_manager.is_first_stage(): - pre_stage_out = self.comm.recv_forward() - model_inputs = self._prepare_next_token(pre_stage_out) - hidden_states = None - # not first stage and in gererate phase - else: - pre_stage_out = self.comm.recv_forward() - model_inputs = { - 'past_key_values': self.mb_manager.cur_kv_cache - } if self.mb_manager.cur_kv_cache is not None else None - hidden_states = pre_stage_out - return pre_stage_out, model_inputs, hidden_states - - def _prepare_next_token(self, inputs: Dict[str, torch.Tensor]): + def _prepare_inputs_for_interval_stage(self): + ''' + Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values + + Returns: + dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` + ''' + model_inputs = { + 'past_key_values': self.mb_manager.cur_kv_cache + } if self.mb_manager.cur_kv_cache is not None else None + return model_inputs + + def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): + ''' + Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values` + `input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end, + `past_key_values` is the past_key_values save in the micro batch manager + + Returns: + dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` + ''' new_mask = self.mb_manager.cur_descrption.attn_mask - new_mask = torch.cat((new_mask, torch.ones((new_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) - self.mb_manager.cur_descrption.attn_mask = new_mask past_key_values = self.mb_manager.cur_descrption.kv_cache - return dict(input_ids=inputs['new_token'], attention_mask=new_mask, past_key_values=past_key_values) + return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values) - def get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: + def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: last_hidden_state = hidden_state[:, -1] input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1) return input_ids @@ -93,11 +100,8 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso """Forward one step of the pipeline Args: - model (Module): Model to be run - input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. - criterion (Callable): Criterion to calculate loss. - accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. - outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + model (Module): Model to be run. + data_iter (Iterable): Data iterator. Returns: Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). @@ -108,20 +112,46 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso # run by round for _ in range(self.round): - state = Status.PREFILL while self.mb_manager.is_micro_batch_done() is False: - pre_stage_out, model_inputs, hidden_states = self._prepare_stage_inputs() - - output_obj = model_forward(model, model_inputs, hidden_states) - - past_key_values = output_obj.get('past_key_values', None) - state = self.mb_manager.step(model_inputs, pre_stage_out, past_key_values) - if self.stage_manager.is_last_stage(): - new_token = self.get_token_id(output_obj['hidden_states']) - self.mb_manager.add_new_tokens(new_token) - if state is not Status.DONE: - self.comm.send_forward({'new_token': new_token}) + inputs_dict = None + new_token = None + output_dict = None + + # First stage and in PREFILL phase, just load the inputs + if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: + inputs_dict = self.load_micro_batch() + output_dict = model_forward(model, inputs_dict, None) + self.mb_manager.step(inputs_dict, output_dict, None) + # In GENERATE phase else: - self.comm.send_forward({'hidden_states': output_obj['hidden_states']}) - output_sequence.extend(self.mb_manager.export_new_tokens()) + # Get hidden_states from previous stage + hidden_states = self.comm.recv_forward() + if self.stage_manager.is_first_stage(): + # First just generate a new token + assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" + logits = model_forward(model, None, hidden_states) + assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {output.keys()}" + new_token = self._get_token_id(logits['logits']) + self.mb_manager.step(None, None, new_token) + # If the current micro batch is not DONE, go through blocks + if self.mb_manager.cur_state is Status.GENERATE: + inputs_dict = self._prepare_inputs_for_new_token(new_token) + output_dict = model_forward(model, inputs_dict, None) + self.mb_manager.step(inputs_dict, output_dict, None) + else: + assert hidden_states is not None, "When not first stage, the hidden states should not be None" + inputs_dict = self._prepare_inputs_for_interval_stage() + output_dict = model_forward(model, inputs_dict, hidden_states) + self.mb_manager.step(inputs_dict, output_dict, None) + + # Current microbatch is not DONE, send hidden_state to next stage + if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state is Status.GENERATE: + self.comm.send_forward({'hidden_states': output_dict['hidden_states']}) + + self.mb_manager.next() + + # All microbatch in current round is DONE + if self.stage_manager.is_first_stage(): + output_sequence.extend(self.mb_manager.export_new_tokens()) + self.mb_manager.clear() return output_sequence diff --git a/tests/test_generate/test_pipeline_infer.py b/tests/test_generate/test_pipeline_infer.py index 5bc2f1857536..47cf9e78d138 100644 --- a/tests/test_generate/test_pipeline_infer.py +++ b/tests/test_generate/test_pipeline_infer.py @@ -34,7 +34,7 @@ def pipeline_inference_test(pp_size, new_length, micro_batch_size): new_length=new_length, micro_batch_size=micro_batch_size) output = engine.inference([inputs]) - if dist.get_rank() == dist.get_world_size() - 1: + if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" From 84e76c1d37b2cb2ce0d575f816ab60f158734cdc Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 7 Sep 2023 15:05:09 +0800 Subject: [PATCH 3/7] [Pipeline inference] support llama pipeline inference (#4647) * support llama pipeline inference * remove tie weight operation --- .../inference/pipeline/modeling/llama.py | 231 ++++++++++++++++++ .../pipeline/policy/llama_ppinfer.py | 50 ++++ 2 files changed, 281 insertions(+) create mode 100644 colossalai/inference/pipeline/modeling/llama.py create mode 100644 colossalai/inference/pipeline/policy/llama_ppinfer.py diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py new file mode 100644 index 000000000000..eeda96df25fd --- /dev/null +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -0,0 +1,231 @@ +from typing import List, Optional, Tuple + +import torch +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class LlamaPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + ''' + + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, + past_key_values_length) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + if past_key_values is None: + past_key_values = tuple([None] * (end_idx - start_idx + 1)) + + for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): + decoder_layer = self.layers[idx] + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + # always return dict for imediate stage + return {'hidden_states': hidden_states, 'past_key_values': next_cache} + + def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {'logits': lm_logits} + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + return outputs diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py new file mode 100644 index 000000000000..bb359de0bb6f --- /dev/null +++ b/colossalai/inference/pipeline/policy/llama_ppinfer.py @@ -0,0 +1,50 @@ +from functools import partial +from typing import Callable, Dict, List, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn import Module + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.policies.llama import LlamaPolicy + +from ..modeling.llama import LlamaPipelineForwards + + +class LlamaForCausalLMPipelinePolicy(LlamaPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers import LlamaForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward(model_cls=LlamaForCausalLM, + new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_first_stage(): + held_layers.append(self.model.lm_head) + return held_layers From 65d300f42acd15076286d54ced2276fc8f2fd2e4 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 21 Sep 2023 16:25:45 +0800 Subject: [PATCH 4/7] [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708) * add benchmark verbose * fix export tokens * fix benchmark verbose * add P2POp style to do p2p communication * modify schedule as p2p type when ppsize is 2 * remove unused code and add docstring --- colossalai/inference/pipeline/engine.py | 3 +- .../inference/pipeline/microbatch_manager.py | 7 +- colossalai/pipeline/p2p.py | 89 ++++++++ colossalai/pipeline/schedule/generate.py | 197 +++++++++++++++++- 4 files changed, 287 insertions(+), 9 deletions(-) diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 9236ee0a7bff..39366d2d69da 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -60,6 +60,7 @@ def __init__( new_length: int = 32, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, + verbose: bool = False, # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, do_sample: bool = False, @@ -71,7 +72,7 @@ def __init__( self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size) - self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager) + self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) self.model = pp_model or self._shardformer(model, model_policy) def inference(self, input_list): diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 7f4b14c17748..b6b008442cfd 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -10,6 +10,7 @@ class Status(Enum): PREFILL = 1 GENERATE = 2 DONE = 3 + COOLDOWN = 4 class MicroBatchDescription(): @@ -52,6 +53,8 @@ def state(self): # TODO: add the condition for early stopping if self.cur_length == self.target_length: return Status.DONE + elif self.cur_length == self.target_length - 1: + return Status.COOLDOWN else: return Status.GENERATE @@ -184,7 +187,9 @@ def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, ne return self.cur_state def export_new_tokens(self): - new_tokens_list = [i.new_tokens[0].tolist() for i in self.mb_descrption_buffer.values()] + new_tokens_list = [] + for i in self.mb_descrption_buffer.values(): + new_tokens_list.extend(i.new_tokens.tolist()) return new_tokens_list def is_micro_batch_done(self): diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index c69bbe6e8521..e18e7295a947 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -160,6 +160,81 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: return object_list[0] +def _p2p_comm_shape( + tensor_send_next: torch.Tensor, + recv_prev: bool, + peer: int, + group: ProcessGroup, +): + send_next_shape = None + recv_prev_shape = None + + if tensor_send_next is not None: + send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64) + if recv_prev: + recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + + ops = [] + if send_next_shape is not None: + send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group) + ops.append(send_next_op) + if recv_prev_shape is not None: + recv_prev_op = dist.P2POp( + dist.irecv, + recv_prev_shape, + peer=peer, + group=group, + ) + ops.append(recv_prev_op) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + if recv_prev_shape is not None: + recv_prev_shape = recv_prev_shape.tolist() + + return recv_prev_shape + + +def _p2p_comm( + tensor_send_next: torch.Tensor, + recv_pre: bool, + peer: int, + group: ProcessGroup, + comm_type: torch.dtype = torch.float32, +): + tensor_recv_prev = None + recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group) + if recv_pre: + tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type) + + ops = [] + if tensor_send_next is not None: + send_next_op = dist.P2POp( + dist.isend, + tensor_send_next, + peer=peer, + group=group, + ) + ops.append(send_next_op) + + if tensor_recv_prev is not None: + recv_prev_op = dist.P2POp( + dist.irecv, + tensor_recv_prev, + peer=peer, + group=group, + ) + ops.append(recv_prev_op) + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + return tensor_recv_prev + + class PipelineP2PCommunication: def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -221,3 +296,17 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + + def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> None: + """ + Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if peer is None: + peer = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer)) + return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index a6ca76a83468..85feebea6e6e 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -1,5 +1,6 @@ +import time from functools import partial -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Iterable, List, Optional, Union import torch import torch.cuda @@ -15,6 +16,17 @@ from .base import PipelineSchedule +class ActionIntervalBuffer(): + + def __int__(self): + self.hidden_states = None + self.new_token = None + + def clear(self): + self.hidden_states = None + self.new_token = None + + class GenerateSchedule(PipelineSchedule): ''' GenerateSchedule is a class that handles the pipeline parallel inference. @@ -22,11 +34,12 @@ class GenerateSchedule(PipelineSchedule): this schedule, the out for each encoding progress is on rank0. Args: - stage_manager (PipelineStageManager): Pipeline stage manager. - mb_manager (MicroBatchManager): Micro batch manager. + stage_manager (`PipelineStageManager`): Pipeline stage manager. + mb_manager (`MicroBatchManager`): Micro batch manager. + verbose (bool): Whether to verbose the information of the pipeline. ''' - def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager) -> None: + def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None: super().__init__(stage_manager) self.comm = PipelineP2PCommunication(stage_manager) self.mb_manager = mb_manager @@ -35,6 +48,8 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None self.num_microbatches: Optional[int] = None + self.verbose = verbose + self.action_interval_buffer = ActionIntervalBuffer() def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -95,9 +110,164 @@ def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1) return input_ids - @torch.no_grad() + def _recv_pre_stage(self) -> Any: + ''' + Receive the output from previous stage + + Returns: + Any: The output from previous stage + ''' + if self.stage_manager.num_stages == 2: + return self.comm.p2p_recv() + return self.comm.recv_forward() + + def LoadStageAction(self, model: Module) -> None: + """ + In this action, 1.load micro_batch 2.do the forward 3.step to update + """ + inputs_dict = self.load_micro_batch() + output_dict = model_forward(model, inputs_dict, None) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + + def GenTokenAction(self, model: Module): + """ + In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update + """ + hidden_states = self.action_interval_buffer.hidden_states + assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" + hidden_states = {'hidden_states': hidden_states} + logits = model_forward(model, None, hidden_states) + assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits['logits']) + + self.mb_manager.step(None, None, new_token) + self.action_interval_buffer.new_token = new_token + self.action_interval_buffer.hidden_states = None + + def HeadEncodingAction(self, model: Module): + """ + In this action, 1.prepare inputs for encoding for first stage. 2.do the forward to get hidden states 3.step to update + """ + new_token = self.action_interval_buffer.new_token + assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None" + inputs_dict = self._prepare_inputs_for_new_token(new_token) + output_dict = model_forward(model, inputs_dict, None) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + + def BodyEncodingAction(self, model: Module): + hidden_states = self.action_interval_buffer.hidden_states + assert hidden_states is not None, "When not first stage, the hidden states should not be None" + inputs_dict = self._prepare_inputs_for_interval_stage() + hidden_states = {'hidden_states': hidden_states} + output_dict = model_forward(model, inputs_dict, hidden_states) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + + def CommAction(self, recv_pre: bool) -> torch.Tensor: + """ + In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage + """ + hidden_states = self.action_interval_buffer.hidden_states + ret = self.comm.p2p_communicate(hidden_states, recv_pre) + + self.action_interval_buffer.hidden_states = ret + + def genAction(self, model: Module): + """ + In p2p step method, we use `P2POp` asynchronous communication method, so the communication need to be done + at the begin of each microbatch, it's a more clear way to use an action list to do so. In this function, it will + generate a sequence action for current state, and do the action one by one. + + Args: + model (Module): Model to be run. + + Returns: + List[Callable]: A list of action, each action is a callable function, and it will be called in order. + """ + actions = [] + if self.stage_manager.is_first_stage(): + if self.mb_manager.cur_state is Status.PREFILL: + actions.append(partial(self.CommAction, False)) + actions.append(partial(self.LoadStageAction, model)) + elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE: + actions.append(partial(self.CommAction, True)) + actions.append(partial(self.GenTokenAction, model)) + actions.append(partial(self.HeadEncodingAction, model)) + elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN: + actions.append(partial(self.CommAction, True)) + actions.append(partial(self.GenTokenAction, model)) + # other stage + else: + actions.append(partial(self.CommAction, True)) + actions.append(partial(self.BodyEncodingAction, model)) + + return actions + + def verbose_info(self, timestamps: List): + prefill = [] + encoder = [] + end2end = [] + for timestamp in timestamps: + prefill.append(timestamp[1] - timestamp[0]) + encoder.append( + sum(timestamp[i + 1] - timestamp[i] for i in range(1, + len(timestamp) - 1)) / (len(timestamp) - 2)) + end2end.append(timestamp[-1] - timestamp[0]) + print(f"Average prefill time: {sum(prefill)/len(prefill)}") + print(f"Average encode time: {sum(encoder)/len(encoder)}") + print(f"Average end2end time: {sum(end2end)/len(end2end)}") + def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: - """Forward one step of the pipeline + if self.stage_manager.num_stages == 2: + return self.generate_step_p2p(model, data_iter) + else: + return self.generate_step_broadcast(model, data_iter) + + @torch.no_grad() + def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """ + Forward one step of the pipeline, when pipeline size is 2, the schedule is a circle, broadcast communication will be + blocked, so we use `P2POp` asynchronous communication method. + + Args: + model (Module): Model to be run. + data_iter (Iterable): Data iterator. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + output_sequence = [] + self.load_batch(data_iter) + model.eval() + + whole_timestamp = [] + + #run by round + for _ in range(self.round): + self.action_interval_buffer.clear() + while self.mb_manager.is_micro_batch_done() is False: + actions = self.genAction(model) + for action in actions: + action() + self.mb_manager.next() + # All microbatch in current round is DONE + if self.stage_manager.is_first_stage(): + output_sequence.extend(self.mb_manager.export_new_tokens()) + else: + self.CommAction(False) + self.mb_manager.clear() + + return output_sequence + + @torch.no_grad() + def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """ + Forward one step of the pipeline Args: model (Module): Model to be run. @@ -110,8 +280,11 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso self.load_batch(data_iter) model.eval() + whole_timestamp = [] # run by round for _ in range(self.round): + timestampes = [[] for _ in range(self.stage_manager.num_stages) + ] if self.verbose and self.stage_manager.is_first_stage() else None while self.mb_manager.is_micro_batch_done() is False: inputs_dict = None new_token = None @@ -120,6 +293,8 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso # First stage and in PREFILL phase, just load the inputs if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: inputs_dict = self.load_micro_batch() + if self.verbose and self.stage_manager.is_first_stage(): + timestampes[self.mb_manager.idx].append(time.time()) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) # In GENERATE phase @@ -130,7 +305,9 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso # First just generate a new token assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" logits = model_forward(model, None, hidden_states) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {output.keys()}" + if self.verbose and self.stage_manager.is_first_stage(): + timestampes[self.mb_manager.idx].append(time.time()) + assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits['logits']) self.mb_manager.step(None, None, new_token) # If the current micro batch is not DONE, go through blocks @@ -154,4 +331,10 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso if self.stage_manager.is_first_stage(): output_sequence.extend(self.mb_manager.export_new_tokens()) self.mb_manager.clear() + if self.verbose and self.stage_manager.is_first_stage(): + whole_timestamp.extend(timestampes) + + if self.verbose and self.stage_manager.is_first_stage(): + self.verbose_info(whole_timestamp) + return output_sequence From 121ab526d038f089b4ad14440c4dea858741c957 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 27 Sep 2023 14:29:27 +0800 Subject: [PATCH 5/7] [Pipeline inference] Refactor code, add docsting, fix bug (#4790) * add benchmark script * update argparse * fix fp16 load * refactor code style * add docstring * polish code * fix test bug --- .../inference/pipeline/benchmark/benchmark.py | 112 ++++++++++++++++++ .../inference/pipeline/benchmark/run.sh | 50 ++++++++ colossalai/inference/pipeline/engine.py | 25 ++-- colossalai/pipeline/p2p.py | 37 +++--- colossalai/pipeline/schedule/generate.py | 63 +++++----- 5 files changed, 230 insertions(+), 57 deletions(-) create mode 100644 colossalai/inference/pipeline/benchmark/benchmark.py create mode 100644 colossalai/inference/pipeline/benchmark/run.sh diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py new file mode 100644 index 000000000000..97dfc6336bea --- /dev/null +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -0,0 +1,112 @@ +import torch +import torch.distributed as dist +import transformers + +import colossalai +import time +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy +import argparse +GIGABYTE = 1024 ** 3 +MEGABYTE = 1024 * 1024 + +colossalai.launch_from_torch(config={}) + +def data_gen(batch_size: int=4, seq_len: int=512): + input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) + attention_mask = torch.ones((1, seq_len), dtype=torch.int32) + data = dict(input_ids=input_ids, attention_mask=attention_mask) + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = batch_size + data[k] = v.to('cuda').repeat(*new_shape) + return data + +def print_details_info(timestamps, model_config, args, whole_end2end): + if dist.get_rank() == 0: + prefill = [] + encoder = [] + end2end = [] + for timestamp in timestamps: + prefill.append(timestamp[1] - timestamp[0]) + encoder.append( + sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2)) + end2end.append(timestamp[-1] - timestamp[0]) + print(whole_end2end) + with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f: + mb_avg_end2end = sum(end2end)/len(end2end) + mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size) + whole_avg_latency = whole_end2end/(args.new_length * args.batch_size) + num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) + num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size + if args.dtype in ['fp16','bf16']: + num_bytes = 2 + else: + num_bytes = 4 + + f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n") + f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000)) + f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000)) + f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000)) + f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) + f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000)) + f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) + f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000)))) + f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12)) + f.write("----------------------------------------------------------\n") + + + if torch.cuda.is_available(): + current_device = torch.cuda.current_device() + + # free memory and the total available memory in bytes + global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info() + memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + memory_reserved = torch.cuda.memory_reserved() + max_memory_reserved = torch.cuda.max_memory_reserved() + with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f: + f.write( + f"\nCurrently using GPU: {current_device}\n" + f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" + f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n" + f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n" + f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" + ) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='toy', help='the size of model') + parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size') + parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length') + parser.add_argument('--new_length', type=int, default=4, help='new tokens length') + parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size') + parser.add_argument('--pp_size', type=int, default=2, help='pipeline size') + parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log') + parser.add_argument('--dtype', type=str, default='fp16', help='data type') + args = parser.parse_args() + + if args.model == 'toy': + model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) + elif args.model == '7b': + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf')) + elif args.model == '13b': + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf')) + else: + raise NotImplementedError + + + engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True) + data = data_gen(args.batch_size, args.seq_len) + + torch.cuda.synchronize() + whole_end2end = time.time() + output, timestamps = engine.inference([data]) + torch.cuda.synchronize() + whole_end2end = time.time() - whole_end2end + + print_details_info(timestamps, model.config, args, whole_end2end) + diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh new file mode 100644 index 000000000000..7d8da858692f --- /dev/null +++ b/colossalai/inference/pipeline/benchmark/run.sh @@ -0,0 +1,50 @@ +script_dir=$(cd "$(dirname "$0")" && pwd) +cd "${script_dir}" + +# 7b, fp32, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp32, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16 32; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp32, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 13b, fp16, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 39366d2d69da..048ead2bccda 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,23 +1,15 @@ -import re -from functools import partial -from types import MethodType -from typing import Callable, List, Optional, Set +from typing import Callable, List, Optional, Set, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.schedule.generate import GenerateSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer._utils import getattr_ from colossalai.shardformer.policies.base_policy import Policy from .microbatch_manager import MicroBatchManager -from .policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy -from .utils import get_suffix_name, set_tensors_to_none class PPInferEngine: @@ -54,6 +46,7 @@ class PPInferEngine: def __init__( self, pp_size: int, + dtype: str = 'fp16', pp_model: nn.Module = None, model: nn.Module = None, model_policy: Policy = None, @@ -72,12 +65,22 @@ def __init__( self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size) + self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) + + assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" + if dtype == 'fp16': + model.half() + elif dtype == 'bf16': + model.to(torch.bfloat16) self.model = pp_model or self._shardformer(model, model_policy) def inference(self, input_list): - out = self.schedule.generate_step(self.model, iter(input_list)) - return out + out, timestamp = self.schedule.generate_step(self.model, iter(input_list)) + if self.verbose: + return out, timestamp + else: + return out def _shardformer(self, model, model_policy): shardconfig = ShardConfig( diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index e18e7295a947..67e198ca0347 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -160,12 +160,27 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: return object_list[0] -def _p2p_comm_shape( +def _p2p_comm( tensor_send_next: torch.Tensor, recv_prev: bool, peer: int, group: ProcessGroup, + comm_dtype: torch.dtype = torch.float16, ): + """ + Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. + + Agrs: + tensor_send_next (torch.Tensor): tensor to be sent to next stage + recv_prev (bool): whether to receive tensor from previous stage + peer (int): rank of the peer + group (ProcessGroup): process group + comm_dtype (torch.dtype): dtype of the tensor to be sent + + Returns: + torch.Tensor: tensor received from previous stage + """ + # send and recv shape send_next_shape = None recv_prev_shape = None @@ -195,20 +210,10 @@ def _p2p_comm_shape( if recv_prev_shape is not None: recv_prev_shape = recv_prev_shape.tolist() - return recv_prev_shape - - -def _p2p_comm( - tensor_send_next: torch.Tensor, - recv_pre: bool, - peer: int, - group: ProcessGroup, - comm_type: torch.dtype = torch.float32, -): + # send and recv data tensor_recv_prev = None - recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group) - if recv_pre: - tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type) + if recv_prev: + tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype) ops = [] if tensor_send_next is not None: @@ -297,7 +302,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) - def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> None: + def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -308,5 +313,5 @@ def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) if peer is None: peer = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer)) + recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype) return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 85feebea6e6e..1a961d3036b8 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -17,6 +17,10 @@ class ActionIntervalBuffer(): + """ + The buffer to save the interval hidden states and new token for stage to use. + + """ def __int__(self): self.hidden_states = None @@ -28,7 +32,7 @@ def clear(self): class GenerateSchedule(PipelineSchedule): - ''' + """ GenerateSchedule is a class that handles the pipeline parallel inference. In our schedule, we place tie weight layer, embedding and lm_head in the same device to save space, so in this schedule, the out for each encoding progress is on rank0. @@ -37,7 +41,7 @@ class GenerateSchedule(PipelineSchedule): stage_manager (`PipelineStageManager`): Pipeline stage manager. mb_manager (`MicroBatchManager`): Micro batch manager. verbose (bool): Whether to verbose the information of the pipeline. - ''' + """ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None: super().__init__(stage_manager) @@ -48,8 +52,10 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None self.num_microbatches: Optional[int] = None - self.verbose = verbose self.action_interval_buffer = ActionIntervalBuffer() + self.verbose = verbose + self.timestamps = None + self.comm_dtype = None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -126,6 +132,9 @@ def LoadStageAction(self, model: Module) -> None: In this action, 1.load micro_batch 2.do the forward 3.step to update """ inputs_dict = self.load_micro_batch() + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) @@ -139,6 +148,9 @@ def GenTokenAction(self, model: Module): assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" hidden_states = {'hidden_states': hidden_states} logits = model_forward(model, None, hidden_states) + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits['logits']) @@ -173,7 +185,7 @@ def CommAction(self, recv_pre: bool) -> torch.Tensor: In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage """ hidden_states = self.action_interval_buffer.hidden_states - ret = self.comm.p2p_communicate(hidden_states, recv_pre) + ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype) self.action_interval_buffer.hidden_states = ret @@ -208,20 +220,6 @@ def genAction(self, model: Module): return actions - def verbose_info(self, timestamps: List): - prefill = [] - encoder = [] - end2end = [] - for timestamp in timestamps: - prefill.append(timestamp[1] - timestamp[0]) - encoder.append( - sum(timestamp[i + 1] - timestamp[i] for i in range(1, - len(timestamp) - 1)) / (len(timestamp) - 2)) - end2end.append(timestamp[-1] - timestamp[0]) - print(f"Average prefill time: {sum(prefill)/len(prefill)}") - print(f"Average encode time: {sum(encoder)/len(encoder)}") - print(f"Average end2end time: {sum(end2end)/len(end2end)}") - def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: if self.stage_manager.num_stages == 2: return self.generate_step_p2p(model, data_iter) @@ -244,11 +242,14 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T output_sequence = [] self.load_batch(data_iter) model.eval() + self.comm_dtype = model.dtype whole_timestamp = [] #run by round for _ in range(self.round): + self.timestamps = [[] for _ in range(self.stage_manager.num_stages) + ] if self.verbose and self.stage_manager.is_first_stage() else None self.action_interval_buffer.clear() while self.mb_manager.is_micro_batch_done() is False: actions = self.genAction(model) @@ -261,8 +262,10 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T else: self.CommAction(False) self.mb_manager.clear() + if self.verbose and self.stage_manager.is_first_stage(): + whole_timestamp.extend(self.timestamps) - return output_sequence + return output_sequence, whole_timestamp @torch.no_grad() def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: @@ -283,8 +286,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t whole_timestamp = [] # run by round for _ in range(self.round): - timestampes = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = [[] for _ in range(self.stage_manager.num_stages) + ] if self.verbose and self.stage_manager.is_first_stage() else None while self.mb_manager.is_micro_batch_done() is False: inputs_dict = None new_token = None @@ -294,7 +297,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: inputs_dict = self.load_micro_batch() if self.verbose and self.stage_manager.is_first_stage(): - timestampes[self.mb_manager.idx].append(time.time()) + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) # In GENERATE phase @@ -306,12 +310,13 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): - timestampes[self.mb_manager.idx].append(time.time()) + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits['logits']) self.mb_manager.step(None, None, new_token) # If the current micro batch is not DONE, go through blocks - if self.mb_manager.cur_state is Status.GENERATE: + if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): inputs_dict = self._prepare_inputs_for_new_token(new_token) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) @@ -322,7 +327,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t self.mb_manager.step(inputs_dict, output_dict, None) # Current microbatch is not DONE, send hidden_state to next stage - if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state is Status.GENERATE: + if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, + Status.COOLDOWN): self.comm.send_forward({'hidden_states': output_dict['hidden_states']}) self.mb_manager.next() @@ -332,9 +338,6 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t output_sequence.extend(self.mb_manager.export_new_tokens()) self.mb_manager.clear() if self.verbose and self.stage_manager.is_first_stage(): - whole_timestamp.extend(timestampes) - - if self.verbose and self.stage_manager.is_first_stage(): - self.verbose_info(whole_timestamp) + whole_timestamp.extend(self.timestamps) - return output_sequence + return output_sequence, whole_timestamp From 9418c07d9f9bad59f2b509da4b2416a0a14ca3cb Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:57:03 +0800 Subject: [PATCH 6/7] [Pipeline inference] Add pipeline inference docs (#4817) * add readme doc * add a ico * Add performance * update table of contents --- colossalai/inference/pipeline/README.md | 84 +++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 colossalai/inference/pipeline/README.md diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md new file mode 100644 index 000000000000..a90d5d6da182 --- /dev/null +++ b/colossalai/inference/pipeline/README.md @@ -0,0 +1,84 @@ +# 🐳 Pipeline Inference + +## Table of Contents +- [💡 Introduction](#introduction) +- [🔗 Design](#design) +- [🔨 Usage](#usage) + - [Example](#example) + - [Quick start](#quick-start) +- [📊 Performance](#performance) + +## Introduction + +`Pipeline Inference` is a module designed to make inference on a pipeline way. In inference systems, although there is no need to store intermediate information such as activations during forward propagation for backward propagation, the weights of some larger models still cannot fit on a single GPU for inference. This requires us to use model parallelism and other methods to reduce the memory occupation on a single GPU. Pipeline parallelism, as one of the traditional model parallelism approaches, has been widely used due to its reduced all-reduce communication requirements and simple layout. The main issue with pipeline parallelism, known as bubbles, can be almost eliminated in inference because the backward propagation that causes bubbles no longer exists in inference. This makes pipeline parallelism almost bubble-free in the ideal scenario where the sequence length is the same across the pipeline. + +## Design + +Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py). + +1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks: + - Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`. + - Run the pipeline inference model. + +2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks: + - Record each micro-batch information, like generated new tokens and kvcache. + - Record each micro-batch inference state, like prefill, generate or done. + - Update the micro-batch information. + +3. `generate` schedule implements the simple pipeline inference layout. When pipeline size is 2, we use `torch.distributed.P2Pop` to implement the communication between stages, mainly to solve the race communication. When pipeline size is larger than 2, we use `torch.distributed.broadcast` which is faster than `torch.distributed.P2Pop`. + +## Usage + +### Example +```python +from colossalai.pipeline import PPInferEngine +# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example. +model = LlamaForCausalLM.from_pretrained('/path/to/model') +inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt") +engine = PPInferEngine( + pp_size=2, + dtype='fp16', + micro_batch_size=1, + new_length=10, + model=model, + model_policy=LlamaForCausalLMPipelinePolicy()) + +output = engine.inference([inputs]) + +``` + +### Quick start +```shell +cd benchmark +sh run.sh +``` + +## Performance + +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G. + +### Llama Throughput(tokens/s) + +#### 7b, fp16 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| +| :---: | :---: | :---: | :---: | :---: | :---: | :---:| +| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM | +| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | +| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 | +| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM | + +#### 7b, fp32 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | +| :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 | +| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM | +| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 | +| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM | + +#### 13b, fp16 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | +| :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 | +| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM | +| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 | +| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM | From 139adda66d93b44ea1af110352b9d3e6916d433c Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 9 Oct 2023 15:35:12 +0800 Subject: [PATCH 7/7] refactor code (#4873) --- colossalai/pipeline/schedule/generate.py | 34 +++++++++---------- .../test_pipeline_infer.py | 0 2 files changed, 17 insertions(+), 17 deletions(-) rename tests/{test_generate => test_infer}/test_pipeline_infer.py (100%) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 1a961d3036b8..8f6acd5fcf4b 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -127,7 +127,7 @@ def _recv_pre_stage(self) -> Any: return self.comm.p2p_recv() return self.comm.recv_forward() - def LoadStageAction(self, model: Module) -> None: + def _load_stage_action(self, model: Module) -> None: """ In this action, 1.load micro_batch 2.do the forward 3.step to update """ @@ -140,7 +140,7 @@ def LoadStageAction(self, model: Module) -> None: self.mb_manager.step(inputs_dict, output_dict, None) self.action_interval_buffer.hidden_states = output_dict['hidden_states'] - def GenTokenAction(self, model: Module): + def _gen_token_action(self, model: Module): """ In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update """ @@ -158,7 +158,7 @@ def GenTokenAction(self, model: Module): self.action_interval_buffer.new_token = new_token self.action_interval_buffer.hidden_states = None - def HeadEncodingAction(self, model: Module): + def _head_encoding_action(self, model: Module): """ In this action, 1.prepare inputs for encoding for first stage. 2.do the forward to get hidden states 3.step to update """ @@ -170,7 +170,7 @@ def HeadEncodingAction(self, model: Module): self.mb_manager.step(inputs_dict, output_dict, None) self.action_interval_buffer.hidden_states = output_dict['hidden_states'] - def BodyEncodingAction(self, model: Module): + def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" inputs_dict = self._prepare_inputs_for_interval_stage() @@ -180,7 +180,7 @@ def BodyEncodingAction(self, model: Module): self.mb_manager.step(inputs_dict, output_dict, None) self.action_interval_buffer.hidden_states = output_dict['hidden_states'] - def CommAction(self, recv_pre: bool) -> torch.Tensor: + def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage """ @@ -189,7 +189,7 @@ def CommAction(self, recv_pre: bool) -> torch.Tensor: self.action_interval_buffer.hidden_states = ret - def genAction(self, model: Module): + def _gen_action(self, model: Module): """ In p2p step method, we use `P2POp` asynchronous communication method, so the communication need to be done at the begin of each microbatch, it's a more clear way to use an action list to do so. In this function, it will @@ -204,19 +204,19 @@ def genAction(self, model: Module): actions = [] if self.stage_manager.is_first_stage(): if self.mb_manager.cur_state is Status.PREFILL: - actions.append(partial(self.CommAction, False)) - actions.append(partial(self.LoadStageAction, model)) + actions.append(partial(self._comm_action, False)) + actions.append(partial(self._load_stage_action, model)) elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE: - actions.append(partial(self.CommAction, True)) - actions.append(partial(self.GenTokenAction, model)) - actions.append(partial(self.HeadEncodingAction, model)) + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._gen_token_action, model)) + actions.append(partial(self._head_encoding_action, model)) elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN: - actions.append(partial(self.CommAction, True)) - actions.append(partial(self.GenTokenAction, model)) + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._gen_token_action, model)) # other stage else: - actions.append(partial(self.CommAction, True)) - actions.append(partial(self.BodyEncodingAction, model)) + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._body_encoding_action, model)) return actions @@ -252,7 +252,7 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T ] if self.verbose and self.stage_manager.is_first_stage() else None self.action_interval_buffer.clear() while self.mb_manager.is_micro_batch_done() is False: - actions = self.genAction(model) + actions = self._gen_action(model) for action in actions: action() self.mb_manager.next() @@ -260,7 +260,7 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T if self.stage_manager.is_first_stage(): output_sequence.extend(self.mb_manager.export_new_tokens()) else: - self.CommAction(False) + self._comm_action(False) self.mb_manager.clear() if self.verbose and self.stage_manager.is_first_stage(): whole_timestamp.extend(self.timestamps) diff --git a/tests/test_generate/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py similarity index 100% rename from tests/test_generate/test_pipeline_infer.py rename to tests/test_infer/test_pipeline_infer.py