diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py
new file mode 100644
index 000000000000..db33ae6fe998
--- /dev/null
+++ b/colossalai/inference/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline import PPInferEngine
+
+__all__ = ['PPInferEngine']
diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py
new file mode 100644
index 000000000000..aff4568f7d08
--- /dev/null
+++ b/colossalai/inference/pipeline/__init__.py
@@ -0,0 +1,3 @@
+from .engine import PPInferEngine
+
+__all__ = ['PPInferEngine']
diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
new file mode 100644
index 000000000000..0c748d725d5d
--- /dev/null
+++ b/colossalai/inference/pipeline/engine.py
@@ -0,0 +1,93 @@
+import re
+from functools import partial
+from types import MethodType
+from typing import Callable, List, Optional, Set
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.pipeline.schedule.generate import GenerateSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer._utils import getattr_
+from colossalai.shardformer.policies.base_policy import Policy
+
+from .microbatch_manager import MicroBatchManager
+from .policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
+from .utils import get_suffix_name, set_tensors_to_none
+
+
+class PPInferEngine:
+    '''
+    PPInferEngine is a class that handles pipeline-parallel inference.
+
+    Args:
+        pp_size (int): the number of pipeline stages.
+        pp_model (`nn.Module`): a model that has already been sharded for pipeline parallelism.
+        model (`nn.Module`): an unsharded model, which will be sharded with `ShardFormer`.
+        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by `ShardFormer` to shard the model.
+        micro_batch_size (int): the micro batch size.
+        micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
+        new_length (int): the number of new tokens to generate for each input sequence.
+        early_stopping (bool): whether to stop generation early (not implemented yet).
+
+    Example:
+
+    ```python
+    from colossalai.inference import PPInferEngine
+    from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+    model = GPT2LMHeadModel.from_pretrained('gpt2')
+    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    # assume inference is run with 4 pipeline stages
+    engine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
+
+    input = ["Hello, my dog is cute, and I like"]
+    tokenized_input = tokenizer(input, return_tensors='pt')
+    output = engine.inference([tokenized_input])
+    ```
+
+    '''
+
+    def __init__(
+        self,
+        pp_size: int,
+        pp_model: nn.Module = None,
+        model: nn.Module = None,
+        model_policy: Policy = None,
+        new_length: int = 32,
+        micro_batch_size: int = 1,
+        micro_batch_buffer_size: int = None,
+        # TODO: implement early_stopping and various generation options
+        early_stopping: bool = False,
+        do_sample: bool = False,
+        num_beams: int = 1,
+    ) -> None:
+        assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
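+        # Build a 1D process-group mesh along the pipeline axis, create the stage manager,
+        # micro-batch manager and generate schedule, then either take the pre-sharded
+        # `pp_model` or shard `model` with ShardFormer according to `model_policy`.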
+        self.pp_size = pp_size
+        self.pg_mesh = ProcessGroupMesh(pp_size)
+        self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
+        self.mb_manager = MicroBatchManager(new_length, micro_batch_size, micro_batch_buffer_size or pp_size)
+        self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager)
+        self.model = pp_model or self._shardformer(model, model_policy)
+
+    def inference(self, input_list):
+        out = self.schedule.generate_step(self.model, iter(input_list))
+        return out
+
+    def _shardformer(self, model, model_policy):
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=None,
+            pipeline_stage_manager=self.stage_manager,
+            enable_tensor_parallelism=False,
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model.cuda()
diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py
new file mode 100644
index 000000000000..f54396bb3747
--- /dev/null
+++ b/colossalai/inference/pipeline/microbatch_manager.py
@@ -0,0 +1,150 @@
+from enum import Enum
+from typing import Dict
+
+import torch
+
+__all__ = ['MicroBatchManager']
+
+
+class Status(Enum):
+    PREFILL = 1
+    GENERATE = 2
+    DONE = 3
+
+
+class MicroBatchDescription():
+
+    def __init__(
+        self,
+        mb_inputs: Dict[str, torch.Tensor],
+        interval_inputs: Dict[str, torch.Tensor],
+        new_length: int,
+    ) -> None:
+        if mb_inputs is not None:
+            assert mb_inputs.get('input_ids') is not None and mb_inputs.get('attention_mask') is not None
+            self.mb_length = mb_inputs['input_ids'].shape[-1]
+            self.attn_mask = mb_inputs['attention_mask']
+            self.input_ids = mb_inputs['input_ids']
+
+        elif interval_inputs is not None:
+            assert interval_inputs.get('hidden_states') is not None
+            self.mb_length = interval_inputs['hidden_states'].shape[-2]
+        else:
+            raise ValueError('mb_inputs and interval_inputs cannot both be None')
+
+        self.target_length = self.mb_length + new_length
+        self.kv_cache = ()
+
+    def update(self, kv_cache):
+        self.kv_cache = kv_cache
+
+    @property
+    def cur_length(self):
+        """
+        Return the current sequence length of the micro batch. When there is no kv_cache, the length is
+        mb_length; otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1.
+        """
+        if len(self.kv_cache) == 0:
+            return self.mb_length
+        return self.kv_cache[0][0].shape[-2] + 1
+
+    @property
+    def state(self):
+        """
+        Return the state of the current micro batch: DONE when the current length has reached the
+        target length, otherwise GENERATE.
+        """
+        if self.cur_length == self.target_length:
+            return Status.DONE
+        else:
+            return Status.GENERATE
+
+
+class MicroBatchManager():
+    '''
+    MicroBatchManager is a class that manages the micro batches used for pipeline inference.
+
+    Args:
+        new_length (int): the number of new tokens to generate for each micro batch.
+        micro_batch_size (int): the micro batch size.
+        micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
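+
+    Example (illustrative values; normally constructed internally by `PPInferEngine`):
+
+    ```python
+    # generate 32 new tokens per sequence, one sequence per micro batch, with 4 in-flight micro batches
+    mb_manager = MicroBatchManager(new_length=32, micro_batch_size=1, micro_batch_buffer_size=4)
+    ```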
+    '''
+
+    def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
+        self.new_length = new_length
+        self.micro_batch_size = micro_batch_size
+        self.buffer_size = micro_batch_buffer_size
+        self.mb_descrption_buffer = {}
+        self.new_tokens_buffer = {}
+        self.idx = 0
+
+    def _add_descrption(self, mb_inputs: Dict[str, torch.Tensor], inter_inputs: Dict[str, torch.Tensor]):
+        self.mb_descrption_buffer[self.idx] = MicroBatchDescription(mb_inputs, inter_inputs, self.new_length)
+
+    def _update_descrption(self, present_kv):
+        self.mb_descrption_buffer[self.idx].update(present_kv)
+
+    def _remove_descrption(self):
+        self.mb_descrption_buffer.pop(self.idx)
+
+    def step(self, mb_inputs=None, inter_inputs=None, present_kv=None):
+        """
+        Update the state of the micro batch manager for the current micro batch.
+
+        Args:
+            mb_inputs (Dict[str, torch.Tensor], optional): The input of the first stage during prefill, a dict like {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}.
+            inter_inputs (Dict[str, torch.Tensor], optional): The input of an intermediate stage (the output of the previous stage), a dict like {'hidden_states': torch.Tensor}.
+            present_kv (tuple, optional): The KV cache of the current micro batch in the current stage.
+        """
+        if self.mb_descrption_buffer.get(self.idx) is None:
+            self._add_descrption(mb_inputs, inter_inputs)
+        self._update_descrption(present_kv)
+        state = self.cur_state
+        self.next()
+        return state
+
+    def next(self):
+        self.idx = (self.idx + 1) % self.buffer_size
+
+    def is_micro_batch_done(self):
+        if len(self.mb_descrption_buffer) == 0:
+            return False
+        for mb in self.mb_descrption_buffer.values():
+            if mb.state != Status.DONE:
+                return False
+        self.mb_descrption_buffer.clear()
+        return True
+
+    def add_new_tokens(self, new_token):
+        if self.idx not in self.new_tokens_buffer:
+            self.new_tokens_buffer[self.idx] = new_token
+        else:
+            self.new_tokens_buffer[self.idx] = torch.cat([self.new_tokens_buffer[self.idx], new_token], dim=-1)
+
+    def export_new_tokens(self):
+        new_tokens_list = [item.tolist() for item in self.new_tokens_buffer.values()]
+        flat_list = [item for sublist in new_tokens_list for item in sublist]
+        self.new_tokens_buffer.clear()
+        return flat_list
+
+    @property
+    def cur_descrption(self) -> MicroBatchDescription:
+        return self.mb_descrption_buffer.get(self.idx)
+
+    @property
+    def cur_kv_cache(self):
+        if self.cur_descrption is None:
+            return None
+        return self.cur_descrption.kv_cache
+
+    @property
+    def cur_state(self):
+        """
+        Return the state of the current micro batch: PREFILL when no description has been created
+        for it yet, otherwise the state recorded in its description.
+        """
+        if self.cur_descrption is None:
+            return Status.PREFILL
+        return self.cur_descrption.state
diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/pipeline/modeling/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py
new file mode 100644
index 000000000000..773fb2a07899
--- /dev/null
+++ b/colossalai/inference/pipeline/modeling/gpt2.py
@@ -0,0 +1,292 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+class GPT2PipelineForwards:
+    '''
+    This class serves as a
micro library for forward function substitution of GPT2 models + under pipeline setting. + ''' + + @staticmethod + def gpt2_model_forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. + # Please refer to original code of transformers for more details. + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
+ attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + else: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # Going through held blocks. 
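+        # Only the transformer blocks assigned to this pipeline stage, i.e. self.h[stage_index[0]:stage_index[1]], are executed here.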
+ start_idx, end_idx = stage_index[0], stage_index[1] + for i, layer_past in zip(range(start_idx, end_idx), past_key_values): + block = self.h[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return {'hidden_states': hidden_states, 'past_key_values': presents} + + @staticmethod + def gpt2_lmhead_model_forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. 
Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index) + + # If not at the last stage, return hidden_states as in GPT2Model + if not stage_manager.is_last_stage(): + return outputs + + hidden_states = outputs['hidden_states'] + lm_logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return {'hidden_states': lm_logits, 'past_key_values': outputs['past_key_values']} diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py new file mode 100644 index 000000000000..3e4ad30f96ed --- /dev/null +++ b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py @@ -0,0 +1,69 @@ +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.policies.gpt2 import GPT2Policy + +from ..modeling.gpt2 import GPT2PipelineForwards + + +class GPT2LMHeadModelPipelinePolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + 
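+            # The tied lm_head is only materialized on the last pipeline stage.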
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        '''The weights of wte and lm_head are shared.'''
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None:
+            if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
+                first_stage, last_stage = 0, stage_manager.num_stages - 1
+                return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+        return []
+
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """Under a pipeline-parallel setting, replace the original HuggingFace forward method with the
+        customized forward method, and add this replacement to the policy."""
+        if not self.pipeline_stage_manager:
+            raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
+        stage_manager = self.pipeline_stage_manager
+        if self.model.__class__.__name__ == 'GPT2Model':
+            module = self.model
+        else:
+            module = self.model.transformer
+
+        layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
+        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py
new file mode 100644
index 000000000000..1a6e8a519397
--- /dev/null
+++ b/colossalai/inference/pipeline/utils.py
@@ -0,0 +1,35 @@
+from typing import List, Optional, Set
+
+import torch.nn as nn
+
+from colossalai.shardformer._utils import getattr_, setattr_
+
+
+def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None:
+    """
+    Set all parameters and buffers of the included submodules of the model to None.
+
+    Args:
+        model (nn.Module): The model to modify.
+        include (Set[str]): Suffix names of the submodules whose tensors are released.
+    """
+    for module_suffix in include:
+        set_module = getattr_(model, module_suffix)
+        for n, p in set_module.named_parameters():
+            setattr_(set_module, n, None)
+        for n, buf in set_module.named_buffers():
+            setattr_(set_module, n, None)
+        setattr_(model, module_suffix, None)
+
+
+def get_suffix_name(suffix: str, name: str):
+    """
+    Get the suffix name of the module, as `suffix.name` when name is a string, `suffix[name]` when name
+    is a digit, and `name` when `suffix` is empty.
+
+    Args:
+        suffix (str): The suffix (dotted path) of the parent module.
+        name (str): The name of the current module.
+    """
+    point = '' if suffix == '' else '.'
+    suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}'
+    return suffix_name
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
new file mode 100644
index 000000000000..e12616655d32
--- /dev/null
+++ b/colossalai/pipeline/schedule/generate.py
@@ -0,0 +1,127 @@
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Union
+
+import torch
+import torch.cuda
+from torch.nn import Module
+from torch.utils._pytree import tree_map
+
+from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
+from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.utils.cuda import get_current_device
+
+from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
+from .base import PipelineSchedule
+
+
+class GenerateSchedule(PipelineSchedule):
+
+    def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager) -> None:
+        super().__init__(stage_manager)
+        self.comm = PipelineP2PCommunication(stage_manager)
+        self.mb_manager = mb_manager
+        self.microbatch_size = mb_manager.micro_batch_size
+        self.batch: Optional[Any] = None
+        self.batch_size: Optional[int] = None
+        self.microbatch_offset: Optional[int] = None
+        self.num_microbatches: Optional[int] = None
+
+    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
+        """Load a batch from data iterator.
+
+        Args:
+            data_iter (Iterable): Data iterator.
+            device (Optional[torch.device], optional): Target device. Defaults to None.
+        """
+        batch = next(data_iter)
+        if device is not None:
+            batch = tree_map(partial(to_device, device=device), batch)
+        self.batch = batch
+        self.batch_size = get_batch_size(batch)
+        self.microbatch_offset = 0
+        assert self.batch_size % self.microbatch_size == 0, \
+            f"Batch size should be divisible by the micro batch size, got batch size {self.batch_size} and micro batch size {self.microbatch_size}"
+        self.num_microbatches = self.batch_size // self.microbatch_size
+        self.round = self.num_microbatches // self.stage_manager.num_stages
+
+    def load_micro_batch(self) -> Any:
+        """Load a micro batch from the current batch.
+
+        Returns:
+            Any: Micro batch.
+ """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) + self.microbatch_offset += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def _prepare_stage_inputs(self): + # first stage and in prefill phase + if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: + pre_stage_out = None + model_inputs = self.load_micro_batch() + hidden_states = None + # first stage and in generate phase + elif self.stage_manager.is_first_stage(): + pre_stage_out = self.comm.recv_forward() + model_inputs = self._prepare_next_token(pre_stage_out) + hidden_states = None + # not first stage and in gererate phase + else: + pre_stage_out = self.comm.recv_forward() + model_inputs = { + 'past_key_values': self.mb_manager.cur_kv_cache + } if self.mb_manager.cur_kv_cache is not None else None + hidden_states = pre_stage_out + return pre_stage_out, model_inputs, hidden_states + + def _prepare_next_token(self, inputs: Dict[str, torch.Tensor]): + new_mask = self.mb_manager.cur_descrption.attn_mask + new_mask = torch.cat((new_mask, torch.ones((new_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) + self.mb_manager.cur_descrption.attn_mask = new_mask + past_key_values = self.mb_manager.cur_descrption.kv_cache + + return dict(input_ids=inputs['new_token'], attention_mask=new_mask, past_key_values=past_key_values) + + def get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: + last_hidden_state = hidden_state[:, -1] + input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1) + return input_ids + + @torch.no_grad() + def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + + Args: + model (Module): Model to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). 
+ """ + output_sequence = [] + self.load_batch(data_iter) + model.eval() + + # run by round + for _ in range(self.round): + state = Status.PREFILL + while self.mb_manager.is_micro_batch_done() is False: + pre_stage_out, model_inputs, hidden_states = self._prepare_stage_inputs() + + output_obj = model_forward(model, model_inputs, hidden_states) + + past_key_values = output_obj.get('past_key_values', None) + state = self.mb_manager.step(model_inputs, pre_stage_out, past_key_values) + if self.stage_manager.is_last_stage(): + new_token = self.get_token_id(output_obj['hidden_states']) + self.mb_manager.add_new_tokens(new_token) + if state is not Status.DONE: + self.comm.send_forward({'new_token': new_token}) + else: + self.comm.send_forward({'hidden_states': output_obj['hidden_states']}) + output_sequence.extend(self.mb_manager.export_new_tokens()) + return output_sequence diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 6ba7dc629958..e1edd77a2d20 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -13,6 +13,7 @@ class PipelineStageManager: Args: pg_mesh (ProcessGroupMesh): Process group mesh. pipeline_axis (int): The axis along which the pipeline is constructed. + is_virtual (bool): Whether to use circle p2p communication, it will make the first and last stage communicate with each other. Attributes: num_stages (int): Number of stages in the pipeline. @@ -25,6 +26,7 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + # init prev and next coord coord = self.pg_mesh.coordinate() # the prev rank of rank0 is the last rank diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index a94e8d42c78e..4261f5ae26c6 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -51,6 +51,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool): booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + torch.cuda.empty_cache() def run_dist(rank, world_size, port): @@ -59,6 +60,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_low_level_zero_checkpointIO(): spawn(run_dist, 2) diff --git a/tests/test_generate/test_pipeline_infer.py b/tests/test_generate/test_pipeline_infer.py new file mode 100644 index 000000000000..5bc2f1857536 --- /dev/null +++ b/tests/test_generate/test_pipeline_infer.py @@ -0,0 +1,63 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers + +import colossalai +from colossalai.inference.pipeline.engine import PPInferEngine +from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + + +def data_gen(): + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +inputs = data_gen() +for k, v in 
inputs.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 16 + inputs[k] = v.to('cuda').repeat(*new_shape) + + +def pipeline_inference_test(pp_size, new_length, micro_batch_size): + model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) + engine = PPInferEngine(pp_size=pp_size, + model=model, + model_policy=GPT2LMHeadModelPipelinePolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size) + output = engine.inference([inputs]) + if dist.get_rank() == dist.get_world_size() - 1: + assert len(output[0]) == new_length, f"{len(output)}, {new_length}" + + +@parameterize('pp_size', [4]) +@parameterize('new_length', [4, 8, 16]) +@parameterize('micro_batch_size', [1, 4]) +@clear_cache_before_run() +def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): + pipeline_inference_test(pp_size, new_length, micro_batch_size) + torch.cuda.empty_cache() + + +def check_pipeline_inference(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_pipeline_inference_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_pipeline_inference(): + spawn(check_pipeline_inference, nprocs=4) + + +if __name__ == '__main__': + test_pipeline_inference()