3 changes: 3 additions & 0 deletions colossalai/inference/__init__.py
@@ -0,0 +1,3 @@
from .pipeline import PPInferEngine

__all__ = ['PPInferEngine']
84 changes: 84 additions & 0 deletions colossalai/inference/pipeline/README.md
@@ -0,0 +1,84 @@
# 🐳 Pipeline Inference

## Table of Contents
- [💡 Introduction](#introduction)
- [🔗 Design](#design)
- [🔨 Usage](#usage)
- [Example](#example)
- [Quick start](#quick-start)
- [📊 Performance](#performance)

## Introduction

`Pipeline Inference` is a module designed to perform inference in a pipelined fashion. Although inference does not need to store intermediate information such as activations for backward propagation, the weights of larger models still cannot fit on a single GPU, which forces us to use model parallelism and other techniques to reduce per-GPU memory usage. Pipeline parallelism, one of the traditional model-parallel approaches, is widely used thanks to its low all-reduce communication requirements and simple layout. Its main drawback, pipeline bubbles, can be almost eliminated in inference because the backward pass that causes them no longer exists. This makes pipeline parallelism nearly bubble-free in the ideal case where the sequence length is the same across the pipeline.

## Design

Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager`, and the `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).

1. `PPInferEngine` is the high-level API for users. It is responsible for the following tasks:
    - Initialize the pipeline inference environment with `PipelineStageManager` and shard the model with `ShardFormer`.
- Run the pipeline inference model.

2. `MicroBatchManager` is a structure that manages micro-batch information. It is responsible for the following tasks:
    - Record each micro-batch's information, such as newly generated tokens and the KV cache.
    - Record each micro-batch's inference state: prefill, generate, or done.
    - Update the micro-batch information.

3. `generate` schedule implements a simple pipeline inference layout. When the pipeline size is 2, we use `torch.distributed.P2POp` to implement communication between stages, mainly to avoid races between the paired send and receive operations, as sketched below. When the pipeline size is larger than 2, we use `torch.distributed.broadcast`, which is faster than `torch.distributed.P2POp`.
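
For illustration, the two-stage exchange could be sketched with `torch.distributed.P2POp` roughly as follows. This is a minimal sketch, not the actual `generate` schedule; the helper name, rank arguments, and the assumption that both stages already know the tensor shape are ours.

```python
import torch
import torch.distributed as dist

def exchange_hidden_states(hidden: torch.Tensor, prev_rank: int, next_rank: int) -> torch.Tensor:
    # Both stages must agree on shape/dtype in advance so the receive buffer can be pre-allocated.
    recv_buf = torch.empty_like(hidden)
    # Posting the paired isend/irecv through a single batch_isend_irecv call lets PyTorch
    # order them safely, avoiding the race/deadlock that blocking send/recv can hit when
    # two stages exchange tensors with each other at the same time.
    ops = [
        dist.P2POp(dist.isend, hidden, next_rank),
        dist.P2POp(dist.irecv, recv_buf, prev_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_buf
```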

## Usage

### Example
```python
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
from transformers import LlamaForCausalLM, LlamaTokenizer

# Suppose the pipeline size is 2 and fp16 is used for inference, with Llama as an example.
model = LlamaForCausalLM.from_pretrained('/path/to/model')
tokenizer = LlamaTokenizer.from_pretrained('/path/to/model')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
engine = PPInferEngine(
    pp_size=2,
    dtype='fp16',
    micro_batch_size=1,
    new_length=10,
    model=model,
    model_policy=LlamaForCausalLMPipelinePolicy())

output = engine.inference([inputs])

```
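
`inference` takes a list of tokenized batches and returns the generated output. When the engine is constructed with `verbose=True`, it additionally returns per-step timestamps for each micro-batch, which is how the benchmark script measures latency.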

### Quick start
```shell
cd benchmark
sh run.sh
```

## Performance

We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` of `Pipeline Inference` against the Hugging Face pipeline. The test environment is 2 * A10, 20 GB. In the tables below, `(1024, 128)` denotes an input sequence length of 1024 tokens with 128 newly generated tokens.

### Llama Throughput (tokens/s)

#### 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |

#### 7b, fp32
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |

#### 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |
3 changes: 3 additions & 0 deletions colossalai/inference/pipeline/__init__.py
@@ -0,0 +1,3 @@
from .engine import PPInferEngine

__all__ = ['PPInferEngine']
112 changes: 112 additions & 0 deletions colossalai/inference/pipeline/benchmark/benchmark.py
@@ -0,0 +1,112 @@
import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

colossalai.launch_from_torch(config={})


def data_gen(batch_size: int = 4, seq_len: int = 512):
input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
data = dict(input_ids=input_ids, attention_mask=attention_mask)
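    # Tile the singleton batch dimension up to batch_size and move each tensor to the GPU.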
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = batch_size
data[k] = v.to('cuda').repeat(*new_shape)
return data

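# `timestamps` is produced by the engine in verbose mode: one list per micro-batch,
# each holding a wall-clock time recorded after every generation step.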
def print_details_info(timestamps, model_config, args, whole_end2end):
if dist.get_rank() == 0:
prefill = []
encoder = []
end2end = []
for timestamp in timestamps:
prefill.append(timestamp[1] - timestamp[0])
            encoder.append(
                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2))
end2end.append(timestamp[-1] - timestamp[0])
        print(f"Whole batch end-to-end time: {whole_end2end:.3f} s")
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
            mb_avg_end2end = sum(end2end) / len(end2end)
            mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
            whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
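            # Rough estimate: ~12 * hidden_size^2 parameters per transformer layer
            # (attention + MLP), divided across pipeline stages.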
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
            if args.dtype in ['fp16', 'bf16']:
num_bytes = 2
else:
num_bytes = 4

f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
f.write("----------------------------------------------------------\n")


if torch.cuda.is_available():
current_device = torch.cuda.current_device()

# free memory and the total available memory in bytes
global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
memory_allocated = torch.cuda.memory_allocated()
max_memory_allocated = torch.cuda.max_memory_allocated()
memory_reserved = torch.cuda.memory_reserved()
max_memory_reserved = torch.cuda.max_memory_reserved()
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
f.write(
f"\nCurrently using GPU: {current_device}\n"
f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
)

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='toy', help='the size of model')
parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
    parser.add_argument('--log_path', type=str, default='./log', help='where to store the benchmark log')
parser.add_argument('--dtype', type=str, default='fp16', help='data type')
args = parser.parse_args()

if args.model == 'toy':
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
elif args.model == '7b':
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
elif args.model == '13b':
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
else:
raise NotImplementedError


    engine = PPInferEngine(pp_size=args.pp_size,
                           dtype=args.dtype,
                           micro_batch_size=args.mb_size,
                           new_length=args.new_length,
                           model=model,
                           model_policy=LlamaForCausalLMPipelinePolicy(),
                           verbose=True)
data = data_gen(args.batch_size, args.seq_len)

torch.cuda.synchronize()
whole_end2end = time.time()
output, timestamps = engine.inference([data])
torch.cuda.synchronize()
whole_end2end = time.time() - whole_end2end

print_details_info(timestamps, model.config, args, whole_end2end)

50 changes: 50 additions & 0 deletions colossalai/inference/pipeline/benchmark/run.sh
@@ -0,0 +1,50 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"

# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=1024 \
--new_length=128 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done

# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=512 \
--new_length=512 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done

# 13b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=1024 \
--new_length=128 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done

# 13b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=512 \
--new_length=512 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done
98 changes: 98 additions & 0 deletions colossalai/inference/pipeline/engine.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from .microbatch_manager import MicroBatchManager


class PPInferEngine:
'''
    PPInferEngine is a class that handles pipeline-parallel inference.

Args:
pp_size (int): the number of pipeline stages.
pp_model (`nn.Module`): the model already in pipeline parallelism style.
        model (`nn.Module`): the model not yet in pipeline style; it will be sharded with `ShardFormer`.
        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by `ShardFormer` to shard the model.
        micro_batch_size (int): the micro batch size.
        micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
        new_length (int): the number of new tokens to generate.
early_stopping (bool): whether to stop early.

Example:

```python
    from colossalai.inference import PPInferEngine
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # assume the model is inferred with 4 pipeline stages
    engine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})

    input = ["Hello, my dog is cute, and I like"]
    tokenized_input = tokenizer(input, return_tensors='pt')
    output = engine.inference([tokenized_input])
```

'''

def __init__(
self,
pp_size: int,
dtype: str = 'fp16',
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
verbose: bool = False,
        # TODO: implement early_stopping and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
self.pp_size = pp_size
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
micro_batch_buffer_size or pp_size)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

        assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
        # Guard the cast: `model` is None when a pre-built `pp_model` is passed in.
        if model is not None:
            if dtype == 'fp16':
                model.half()
            elif dtype == 'bf16':
                model.to(torch.bfloat16)
        self.model = pp_model or self._shardformer(model, model_policy)

def inference(self, input_list):
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
if self.verbose:
return out, timestamp
else:
return out

def _shardformer(self, model, model_policy):
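        # Pipeline-only sharding: tensor parallelism and all fusion/optimization flags stay disabled here.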
shardconfig = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()