3 changes: 3 additions & 0 deletions colossalai/inference/__init__.py
@@ -0,0 +1,3 @@
from .pipeline import PPInferEngine

__all__ = ['PPInferEngine']
3 changes: 3 additions & 0 deletions colossalai/inference/pipeline/__init__.py
@@ -0,0 +1,3 @@
from .engine import PPInferEngine

__all__ = ['PPInferEngine']
93 changes: 93 additions & 0 deletions colossalai/inference/pipeline/engine.py
@@ -0,0 +1,93 @@
import re
from functools import partial
from types import MethodType
from typing import Callable, List, Optional, Set

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.base_policy import Policy

from .microbatch_manager import MicroBatchManager
from .policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
from .utils import get_suffix_name, set_tensors_to_none


class PPInferEngine:
'''
PPInferEngine is a class that handles pipeline-parallel inference.

Args:
pp_size (int): the number of pipeline stages.
pp_model (`nn.Module`): the model already in pipeline parallelism style.
model (`nn.Module`): the model not yet in pipeline style; it will be sharded by `ShardFormer`.
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by `ShardFormer` to shard the model.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
new_length (int): the number of new tokens to generate for each sequence.
early_stopping (bool): whether to stop early.

Example:

```python
from colossalai.inference import PPInferEngine
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# assume the model is inferred with 4 pipeline stages
engine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})

input = ["Hello, my dog is cute, and I like"]
tokenized_input = tokenizer(input, return_tensors='pt')
output = engine.inference([tokenized_input])
```

'''

def __init__(
self,
pp_size: int,
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
# TODO: implement early_stopping, and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
self.pp_size = pp_size
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.mb_manager = MicroBatchManager(new_length, micro_batch_size, micro_batch_buffer_size or pp_size)
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager)
self.model = pp_model or self._shardformer(model, model_policy)

def inference(self, input_list):
out = self.schedule.generate_step(self.model, iter(input_list))
return out

def _shardformer(self, model, model_policy):
shardconfig = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()
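Below is a minimal, hypothetical driver sketch (not part of the diff) showing how `PPInferEngine` might be used end to end across four pipeline stages. It assumes one process per stage (e.g. launched with `torchrun`), that `colossalai.launch_from_torch` is the appropriate initialization call, and it reuses the `GPT2LMHeadModelPipelinePolicy` imported in `engine.py`; treat the launch and decoding details as illustrative rather than the PR's prescribed API.

```python
import colossalai
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy

# Assumption: one process per pipeline stage, e.g. `torchrun --nproc_per_node=4 ...`
colossalai.launch_from_torch(config={})

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

engine = PPInferEngine(
    pp_size=4,                                   # number of pipeline stages
    model=model,
    model_policy=GPT2LMHeadModelPipelinePolicy(),
    new_length=32,                               # tokens to generate per sequence
    micro_batch_size=1,
)

inputs = tokenizer(["Hello, my dog is cute, and I like"], return_tensors='pt')
# `inference` drives GenerateSchedule.generate_step over the micro batches and
# returns the newly generated token ids exported by the MicroBatchManager.
output_token_ids = engine.inference([inputs])
print(tokenizer.batch_decode(output_token_ids))
```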
150 changes: 150 additions & 0 deletions colossalai/inference/pipeline/microbatch_manager.py
@@ -0,0 +1,150 @@
from enum import Enum
from typing import Dict

import torch

__all__ = ['MicroBatchManager']


class Status(Enum):
PREFILL = 1
GENERATE = 2
DONE = 3


class MicroBatchDescription():

def __init__(
self,
mb_inputs: Dict[str, torch.Tensor],
interval_inputs: Dict[str, torch.Tensor],
new_length: int,
) -> None:
if mb_inputs is not None:
assert mb_inputs.get('input_ids') is not None and mb_inputs.get('attention_mask') is not None
self.mb_length = mb_inputs['input_ids'].shape[-1]
self.attn_mask = mb_inputs['attention_mask']
self.input_ids = mb_inputs['input_ids']

elif interval_inputs is not None:
assert interval_inputs.get('hidden_states') is not None
self.mb_length = interval_inputs['hidden_states'].shape[-2]
else:
raise ValueError('mb_inputs and interval_inputs cannot both be None')

self.target_length = self.mb_length + new_length
self.kv_cache = ()

def update(self, kv_cache):
self.kv_cache = kv_cache

@property
def cur_length(self):
"""
Return the current sequence length of the micro batch. When there is no kv_cache, the length is `mb_length`;
otherwise it is `kv_cache[0][0].shape[-2] + 1`.

"""
if len(self.kv_cache) == 0:
return self.mb_length
return self.kv_cache[0][0].shape[-2] + 1

@property
def state(self):
"""
Return the state of the current micro batch. When the current length equals the target length,
the state is DONE; otherwise it is GENERATE.

"""
if self.cur_length == self.target_length:
return Status.DONE
else:
return Status.GENERATE


class MicroBatchManager():
'''
MicroBatchManager is a class that manages micro batches during pipeline inference.

Args:
new_length (int): the number of new tokens to generate for each sequence.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
'''

def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
self.new_length = new_length
self.micro_batch_size = micro_batch_size
self.buffer_size = micro_batch_buffer_size
self.mb_descrption_buffer = {}
self.new_tokens_buffer = {}
self.idx = 0

def _add_descrption(self, mb_inputs: Dict[str, torch.Tensor], inter_inputs: Dict[str, torch.Tensor]):
self.mb_descrption_buffer[self.idx] = MicroBatchDescription(mb_inputs, inter_inputs, self.new_length)

def _update_descrption(self, present_kv):
self.mb_descrption_buffer[self.idx].update(present_kv)

def _remove_descrption(self):
self.mb_descrption_buffer.pop(self.idx)

def step(self, mb_inputs=None, inter_inputs=None, present_kv=None):
"""
Update the state of the microbatch manager.

Args:
mb_inputs (Dict[str, torch.Tensor], optional): The input of the first stage during prefill, a dict like {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}.
inter_inputs (Dict[str, torch.Tensor], optional): The input of an intermediate stage (the output of the previous stage), a dict like {'hidden_states': torch.Tensor}.
present_kv (tuple, optional): The KV cache of the current micro batch in the current stage.
"""
if self.mb_descrption_buffer.get(self.idx) is None:
self._add_descrption(mb_inputs, inter_inputs)
self._update_descrption(present_kv)
state = self.cur_state
self.next()
return state

def next(self):
self.idx = (self.idx + 1) % self.buffer_size

def is_micro_batch_done(self):
if len(self.mb_descrption_buffer) == 0:
return False
for mb in self.mb_descrption_buffer.values():
if mb.state != Status.DONE:
return False
self.mb_descrption_buffer.clear()
return True

def add_new_tokens(self, new_token):
if self.idx not in self.new_tokens_buffer:
self.new_tokens_buffer[self.idx] = new_token
else:
self.new_tokens_buffer[self.idx] = torch.cat([self.new_tokens_buffer[self.idx], new_token], dim=-1)

def export_new_tokens(self):
new_tokens_list = [item.tolist() for item in self.new_tokens_buffer.values()]
flat_list = [item for sublist in new_tokens_list for item in sublist]
self.new_tokens_buffer.clear()
return flat_list

@property
def cur_descrption(self) -> MicroBatchDescription:
return self.mb_descrption_buffer.get(self.idx)

@property
def cur_kv_cache(self):
if self.cur_descrption is None:
return None
return self.cur_descrption.kv_cache

@property
def cur_state(self):
"""
Return the state of the current micro batch. When the current description is None, the state is PREFILL.

"""
if self.cur_descrption is None:
return Status.PREFILL
return self.cur_descrption.state
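To make the micro-batch lifecycle concrete, here is a small, hypothetical single-stage walkthrough (not part of the diff). The KV-cache tuples are dummy tensors shaped only so that `kv_cache[0][0].shape[-2]` reflects the cached sequence length, and the token ids are toy values; the import path follows the files added in this PR.

```python
import torch

from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status

# One micro batch, buffer size 1, generate 2 new tokens per sequence.
manager = MicroBatchManager(new_length=2, micro_batch_size=1, micro_batch_buffer_size=1)

prompt = {
    'input_ids': torch.tensor([[15496, 11, 616]]),          # toy prompt, seq_len = 3
    'attention_mask': torch.ones(1, 3, dtype=torch.long),
}

assert manager.cur_state == Status.PREFILL                   # no description yet -> PREFILL

# Prefill step: cached length is 3, so cur_length = 3 + 1 = 4 < target 5 -> GENERATE.
kv = ((torch.zeros(1, 1, 3, 8), torch.zeros(1, 1, 3, 8)),)   # fake (key, value) for one layer
manager.add_new_tokens(torch.tensor([[318]]))
assert manager.step(mb_inputs=prompt, present_kv=kv) == Status.GENERATE

# Decode step: cached length grows to 4, so cur_length = 5 == target -> DONE.
kv = ((torch.zeros(1, 1, 4, 8), torch.zeros(1, 1, 4, 8)),)
manager.add_new_tokens(torch.tensor([[845]]))
assert manager.step(present_kv=kv) == Status.DONE

assert manager.is_micro_batch_done()
print(manager.export_new_tokens())                           # [[318, 845]]
```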