Merged
46 commits
f2b605c
[infer] Infer/llama demo (#4503)
CjhHa1 Aug 24, 2023
b52362a
[Kernels] add inference token attention kernel (#4505)
Xu-Kai Aug 24, 2023
c3d3f08
[Kernels] add necessary kernels (llama & bloom) for attention forward…
tiandiao123 Aug 24, 2023
049b3d4
combine codes (#4509)
tiandiao123 Aug 24, 2023
b20f424
[feature] add KV cache manager for llama & bloom inference (#4495)
yuanheng-zhao Aug 24, 2023
fb03ff5
[Bug FIx] import llama context ops fix (#4524)
tiandiao123 Aug 28, 2023
2d86602
[Infer] Add TPInferEngine and fix file path (#4532)
yuanheng-zhao Aug 29, 2023
f7afa74
Add Inference test for llama (#4508)
isky-cd Aug 30, 2023
8da6320
[infer] Add Bloom inference policy and replaced methods (#4512)
yuanheng-zhao Aug 30, 2023
27407ec
Revert "[infer] Add Bloom inference policy and replaced methods (#451…
tiandiao123 Aug 30, 2023
7fb971b
[Doc] Add colossal inference doc (#4549)
CjhHa1 Aug 30, 2023
7b26e26
[infer] Add Bloom inference policy and replaced methods (#4553)
yuanheng-zhao Aug 30, 2023
f592598
Fix Bugs In Llama Model Forward (#4550)
isky-cd Aug 30, 2023
230f517
[doc] add colossal inference fig (#4554)
CjhHa1 Aug 30, 2023
57d4aec
[NFC] fix docstring for colossal inference (#4555)
yuanheng-zhao Aug 31, 2023
a5f247a
fix docstring in llama modeling (#4557)
isky-cd Aug 31, 2023
b289497
[Infer] check import vllm (#4559)
Xu-Kai Aug 31, 2023
5ef07d8
[DOC] add installation req (#4561)
tiandiao123 Aug 31, 2023
483b937
[Feature] rms-norm transfer into inference llama.py (#4563)
tiandiao123 Aug 31, 2023
66454d9
[infer] Fix tp inference engine (#4564)
yuanheng-zhao Aug 31, 2023
d7dabb2
reset shardformer llama (#4569)
Xu-Kai Aug 31, 2023
53205ba
[infer] Fix engine - tensors on different devices (#4570)
yuanheng-zhao Aug 31, 2023
4705179
[codefactor] Feature/colossal inference (#4579)
CjhHa1 Sep 1, 2023
fb6b22f
change coding (#4581)
tiandiao123 Sep 1, 2023
8bd5cdc
[doc] complete README of colossal inference (#4585)
CjhHa1 Sep 1, 2023
594abdf
[doc]update readme (#4586)
tiandiao123 Sep 1, 2023
642c44c
bug fix: fix bus in llama and bloom (#4588)
isky-cd Sep 1, 2023
bbe5367
[BUG FIX]Fix test engine in CI and non-vllm kernels llama forward (#…
tiandiao123 Sep 4, 2023
b9fbf13
[Kernel]Rmsnorm fix (#4598)
tiandiao123 Sep 4, 2023
da77c97
[Bug Fix]Fix bugs in llama (#4601)
isky-cd Sep 4, 2023
b34e44e
[kernel] Add triton layer norm & replace norm for bloom (#4609)
yuanheng-zhao Sep 5, 2023
467f870
[Infer] Bug fix rotary embedding in llama (#4608)
Xu-Kai Sep 5, 2023
f8b28ec
[bench] Add bloom inference benchmark (#4621)
yuanheng-zhao Sep 5, 2023
19fc77d
trivial - uncomment for testing (#4622)
yuanheng-zhao Sep 5, 2023
ab73976
[Infer] add check triton and cuda version for tests (#4627)
Xu-Kai Sep 6, 2023
f13e787
Update sharder.py (#4629)
CjhHa1 Sep 6, 2023
8e3a8b5
[Inference] Hot fix some bugs and typos (#4632)
CjhHa1 Sep 6, 2023
be764b3
[typo]Comments fix (#4633)
tiandiao123 Sep 6, 2023
cd99da5
bug fix: fix some bugs in test_llama and test_bloom (#4635)
isky-cd Sep 6, 2023
7d4b00b
[Infer] delete benchmark in tests and fix bug for llama and bloom (#4…
Xu-Kai Sep 6, 2023
c5dc478
[Fix] Revise TPInferEngine, inference tests and benchmarks (#4642)
yuanheng-zhao Sep 6, 2023
2a98d75
modify utils filename for infer ops test (#4657)
yuanheng-zhao Sep 7, 2023
e2e96d4
[Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)
yuanheng-zhao Sep 8, 2023
c90a0d3
[NFC] use args for infer benchmarks (#4674)
yuanheng-zhao Sep 8, 2023
f0e12d8
revise infer default (#4683)
CjhHa1 Sep 11, 2023
02be854
[Fix] optimize/shard model in TPInferEngine init (#4684)
yuanheng-zhao Sep 11, 2023
32 changes: 32 additions & 0 deletions LICENSE
@@ -396,3 +396,35 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

---------------- LICENSE FOR VLLM TEAM ----------------

from VLLM TEAM:

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/vllm-project/vllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

---------------- LICENSE FOR LIGHTLLM TEAM ----------------

from LIGHTLLM TEAM:

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/ModelTC/lightllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
117 changes: 117 additions & 0 deletions colossalai/inference/README.md
@@ -0,0 +1,117 @@
# 🚀 Colossal-Inference

## Table of contents

## Introduction

`Colossal Inference` is a module containing Colossal-AI's purpose-built inference framework, featuring high performance, stability, and ease of use. It incorporates the strengths of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM, and FlashAttention, and combines them with the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.

## Design

Colossal Inference is composed of three main components:

1. High-performance kernels and ops, inspired by existing libraries and modified accordingly.
2. An efficient memory management mechanism, which includes a key-value cache manager, allowing for zero memory waste during inference (see the sketch after this list).
   1. `cache manager`: serves as a memory manager for the key-value cache; it integrates functions such as memory allocation, indexing, and release.
   2. `batch_infer_info`: holds all essential elements of a batch inference and is updated every batch.
3. A high-level inference engine combined with `Shardformer`, which allows our inference framework to easily invoke and utilize various parallel methods.
   1. `engine.TPInferEngine`: a high-level interface that integrates with Shardformer, especially for multi-card (tensor-parallel) inference.
   2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods used for llama inference (llama in this example).
   3. `policies.llama.LlamaModelInferPolicy`: contains the policies for `llama` models, used to invoke `shardformer` and partition the model forward for tensor parallelism.
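
Below is a minimal sketch of how the memory-management pieces fit together: a `MemoryManager` pre-allocates the KV cache, and a `BatchInferState` tracks one batch against it. The `MemoryManager` constructor arguments here are an assumption inferred from how it is used in this PR and may not match the exact signature; the `BatchInferState` fields are taken from the diff shown further down this page.

```python
import torch

from colossalai.inference.tensor_parallel import MemoryManager
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState

max_batch_size, max_input_len, max_output_len = 4, 1024, 128
head_num, head_dim, layer_num = 32, 128, 32    # llama-7b-style shapes (assumed)

# one KV-cache slot per token the engine may ever hold in flight
cache_size = max_batch_size * (max_input_len + max_output_len)
cache_manager = MemoryManager(cache_size, torch.float16, head_num, head_dim, layer_num)

# per-forward bookkeeping for a batch of 4 prompts of length 1024 each
state = BatchInferState(batch_size=max_batch_size, max_len_in_batch=max_input_len)
state.set_cache_manager(cache_manager)
state.seq_len = torch.full((max_batch_size,), max_input_len, dtype=torch.int32)
print(state.total_token_num)    # 4096: sum of per-sequence lengths, not batch * max_len
```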

## Pipeline of inference

In this section we discuss how Colossal Inference works and integrates with `Shardformer`. The details can be found in our code.

![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Colossal-inference.png)

## Roadmap of our implementation

- [x] Design cache manager and batch infer state
- [x] Design TPInferEngine to integrate with `Shardformer`
- [x] Register corresponding high-performance `kernel` and `ops`
- [x] Design policies and forwards (e.g. `Llama` and `Bloom`)
- [x] policy
- [x] context forward
- [x] token forward
- [ ] Replace the kernels with `faster-transformer` in the token-forward stage
- [ ] Support all models
- [x] Llama
- [x] Bloom
- [ ] Chatglm2
- [ ] Benchmarking for all models

## Get started

### Installation

```bash
pip install -e .
```

### Requirements

Dependencies:

```bash
pytorch==1.13.1 (gpu)
cuda>=11.6
transformers==4.30.2
triton==2.0.0.dev20221202
# to install vllm, please use this branch: https://github.com/tiandiao123/vllm/tree/setup_branch
vllm
# to install flash-attention, please use commit hash 67ae6fd74b4bc99c36b2ce524cf139c35663793c
flash-attention
```
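
A quick sanity check for the pinned versions above (a minimal sketch; the expected values come from the list and are not enforced automatically):

```python
import torch
import triton

print(torch.__version__)     # expect 1.13.1
print(torch.version.cuda)    # expect >= 11.6
print(triton.__version__)    # expect 2.0.0.dev20221202
```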

### Docker

You can use `docker run` to set up the environment in a container:

```bash
# env: python==3.8, cuda 11.6, pytorch==1.13.1, triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support
docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
```

### Dive into fast-inference!

Example files are in:

```bash
cd colossalai.examples
python xx
```
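
Until the example scripts are finalized, here is a hedged sketch of end-to-end tensor-parallel llama inference with `TPInferEngine`. The `ShardConfig` flags, engine constructor arguments, and `generate` keyword arguments are assumptions based on this PR's engine and tests, and may differ from the final API:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig

# launch the Colossal-AI runtime; run with e.g. `torchrun --nproc_per_node 2 demo.py`
colossalai.launch_from_torch(config={})

tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama-7b")
model = LlamaForCausalLM.from_pretrained("/path/to/llama-7b", torch_dtype=torch.float16)

# shardformer shards the model and swaps in the inference forwards/policies
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
engine = TPInferEngine(model, shard_config, max_batch_size=8, max_input_len=1024, max_output_len=128)

inputs = tokenizer("Introduce some landmarks in Beijing", return_tensors="pt")
outputs = engine.generate(inputs, do_sample=False, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```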

## Performance

### Environment

We conducted multiple benchmark tests to evaluate performance, comparing the inference `latency` and `throughput` of `colossal-inference` against the original `hugging-face torch fp16` baseline.

For each model, experiments were run at multiple batch sizes under a consistent model configuration of `7 billion (7b)` parameters, `1024` input length, and `128` output length. The results are as follows (due to time constraints, the evaluation has so far covered only single-GPU `A100` performance; multi-GPU performance will be addressed in the future):

### Single GPU Performance:

Currently the stats below are measured on a single A100 GPU. Token latency is computed from the average of the context-forward and decoding-forward processes, i.e. both stages are combined when measuring token generation time. We are actively developing new features and methods to further optimize the performance of LLM inference. Please stay tuned.

#### Llama

| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
| colossal-inference | 326.4 | 582.72 | 816.64 |

![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png)

#### Bloom

| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
| colossal-inference | 323.28 | 538.52 | 611.64 |

![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png)

The results of more models are coming soon!
4 changes: 4 additions & 0 deletions colossalai/inference/tensor_parallel/__init__.py
@@ -0,0 +1,4 @@
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager

__all__ = ['MemoryManager', 'TPInferEngine']
55 changes: 55 additions & 0 deletions colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -0,0 +1,55 @@
# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Optional

import torch

from .kvcache_manager import MemoryManager


@dataclass
class BatchInferState:
    r"""
    Information to be passed to and used by the model for a batch of inputs
    during a single forward.

    Covers the KV cache manager, per-sequence cache bookkeeping (block
    locations, start offsets, sequence lengths), and per-stage state for the
    context (prefill) and decode phases.
    """
    batch_size: int
    max_len_in_batch: int

    cache_manager: Optional[MemoryManager] = None

    block_loc: Optional[torch.Tensor] = None
    start_loc: Optional[torch.Tensor] = None
    seq_len: Optional[torch.Tensor] = None
    past_key_values_len: Optional[int] = None

    is_context_stage: bool = False
    context_mem_index: Optional[torch.Tensor] = None
    decode_is_contiguous: Optional[bool] = None
    decode_mem_start: Optional[int] = None
    decode_mem_end: Optional[int] = None
    decode_mem_index: Optional[torch.Tensor] = None
    decode_layer_id: Optional[int] = None

    device: torch.device = torch.device('cuda')

    @property
    def total_token_num(self):
        # sum of the actual sequence lengths in the batch (excludes padding);
        # not self.batch_size * self.max_len_in_batch
        assert self.seq_len is not None and self.seq_len.size(0) > 0
        return int(torch.sum(self.seq_len))

    def set_cache_manager(self, manager: MemoryManager):
        self.cache_manager = manager

    @staticmethod
    def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
                       alloc_mem_index: torch.Tensor):
        """In-place update of the block-location mapping based on the sequence
        lengths of the inputs in the current batch: each row of ``b_loc`` is
        right-aligned so that its last ``cur_seq_len`` slots point at the
        newly allocated cache indices."""
        start_index = 0
        seq_len_numpy = seq_len.cpu().numpy()
        for i, cur_seq_len in enumerate(seq_len_numpy):
            b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = \
                alloc_mem_index[start_index:start_index + cur_seq_len]
            start_index += cur_seq_len
        return
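
A hedged toy example of how `init_block_loc` right-aligns each sequence's cache indices within a row (shapes made up for illustration):

```python
import torch

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState

max_len_in_batch = 5
seq_len = torch.tensor([2, 3])              # two sequences, lengths 2 and 3
b_loc = torch.zeros((2, max_len_in_batch), dtype=torch.long)
alloc_mem_index = torch.arange(10, 15)      # 5 contiguous cache slots

BatchInferState.init_block_loc(b_loc, seq_len, max_len_in_batch, alloc_mem_index)
print(b_loc)
# tensor([[ 0,  0,  0, 10, 11],
#         [ 0,  0, 12, 13, 14]])
```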