4 changes: 4 additions & 0 deletions colossalai/shardformer/inference/__init__.py
@@ -0,0 +1,4 @@
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager

__all__ = ['BatchInferState', 'MemoryManager']
52 changes: 52 additions & 0 deletions colossalai/shardformer/inference/batch_infer_state.py
@@ -0,0 +1,52 @@
# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Any

import torch

from .kvcache_manager import MemoryManager


@dataclass
class BatchInferState:
    r"""
    Information to be passed and used for a batch of inputs during
    a single model forward
    """
    batch_size: int
    max_len_in_batch: int

    cache_manager: MemoryManager = None

    block_loc: torch.Tensor = None
    start_loc: torch.Tensor = None
    seq_len: torch.Tensor = None

    is_context_stage: bool = False
    context_mem_index: torch.Tensor = None
    decode_is_contiguous: bool = None
    decode_mem_start: int = None
    decode_mem_end: int = None
    decode_mem_index: torch.Tensor = None
    decode_layer_id: int = None

    device: torch.device = torch.device('cuda')

    @property
    def total_token_num(self):
        return self.batch_size * self.max_len_in_batch

    def set_cache_manager(self, manager: MemoryManager):
        self.cache_manager = manager

    @staticmethod
    def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
                       alloc_mem_index: torch.Tensor):
        """ In-place update of the block loc mapping based on the sequence lengths of the inputs in the current batch """
        start_index = 0
        seq_len_numpy = seq_len.cpu().numpy()
        for i, cur_seq_len in enumerate(seq_len_numpy):
            # right-align each sequence's allocated cache indexes within its row
            b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = \
                alloc_mem_index[start_index:start_index + cur_seq_len]
            start_index += cur_seq_len
        return
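
To make the behavior of init_block_loc concrete, here is a minimal usage sketch (not part of this diff). The batch shape, sequence lengths, and allocated indexes are made up for illustration, and the example stays on CPU so it runs without a GPU, assuming colossalai is importable:

import torch

from colossalai.shardformer.inference import BatchInferState

# Hypothetical batch: 2 sequences of lengths 3 and 5, so max_len_in_batch = 5.
seq_len = torch.tensor([3, 5], dtype=torch.int32)
max_len_in_batch = 5
block_loc = torch.zeros((2, max_len_in_batch), dtype=torch.long)

# Pretend the cache manager handed back 8 slots (3 + 5 tokens) with indexes 0..7.
alloc_mem_index = torch.arange(8)

BatchInferState.init_block_loc(block_loc, seq_len, max_len_in_batch, alloc_mem_index)
# Each row is right-aligned to max_len_in_batch:
#   row 0 -> [0, 0, 0, 1, 2]   (cache slots 0..2 in the last 3 columns)
#   row 1 -> [3, 4, 5, 6, 7]
print(block_loc)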
116 changes: 116 additions & 0 deletions colossalai/shardformer/inference/kvcache_manager.py
@@ -0,0 +1,116 @@
# Adapted from lightllm/common/mem_manager.py
# of the ModelTC/lightllm GitHub repository
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
#
# Copyright 2023 ModelTC Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from colossalai.logging import get_dist_logger


class MemoryManager:
    r"""
    Manage token block indexes and allocate physical memory for the key and value cache

    Args:
        size: maximum number of tokens, used as the size of the key and value buffers
        dtype: data type of the cached keys and values
        head_num: number of attention heads the memory manager is responsible for
        head_dim: embedding size per head
        layer_num: the number of layers in the model
        device: device used to store the key and value cache
    """

    def __init__(self,
                 size: int,
                 dtype: torch.dtype,
                 head_num: int,
                 head_dim: int,
                 layer_num: int,
                 device: torch.device = torch.device('cuda')):
        self.logger = get_dist_logger(__name__)
        self.available_size = size
        self.past_key_values_length = 0
        self._init_mem_states(size, device)
        self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)

    def _init_mem_states(self, size, device):
        """ Initialize tensors used to manage memory states """
        self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
        self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
        self.indexes = torch.arange(0, size, dtype=torch.long, device=device)

    def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
        """ Initialize key buffer and value buffer on the specified device """
        self.key_buffer = [
            torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
        ]
        self.value_buffer = [
            torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
        ]

    @torch.no_grad()
    def alloc(self, required_size):
        """ Allocate space of required_size by returning indexes of available physical slots """
        if required_size > self.available_size:
            self.logger.warning(f"Not enough cache: required_size {required_size} "
                                f"left_size {self.available_size}")
            return None
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
        select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
        select_index = self.indexes[select_index]
        self.mem_state[select_index] = 0
        self.available_size -= len(select_index)
        return select_index

    @torch.no_grad()
    def alloc_contiguous(self, required_size):
        """ Allocate a contiguous block of required_size slots """
        if required_size > self.available_size:
            self.logger.warning(f"Not enough cache: required_size {required_size} "
                                f"left_size {self.available_size}")
            return None
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
        sum_size = len(self.mem_cum_sum)
        # sliding-window sum: loc_sums[i] is the number of free slots in the window [i, i + required_size)
        loc_sums = (self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + 1]
                    + self.mem_state[0:sum_size - required_size + 1])
        can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
        if can_used_loc.shape[0] == 0:
            self.logger.info(f"Not enough contiguous cache: required_size {required_size} "
                             f"left_size {self.available_size}")
            return None
        start_loc = can_used_loc[0]
        select_index = self.indexes[start_loc:start_loc + required_size]
        self.mem_state[select_index] = 0
        self.available_size -= len(select_index)
        start = start_loc.item()
        end = start + required_size
        return select_index, start, end

    @torch.no_grad()
    def free(self, free_index):
        """ Free memory by updating memory states based on the given indexes """
        self.available_size += free_index.shape[0]
        self.mem_state[free_index] = 1

    @torch.no_grad()
    def free_all(self):
        """ Free all memory by resetting memory states """
        self.available_size = len(self.mem_state)
        self.mem_state[:] = 1
        self.past_key_values_length = 0
        self.logger.info("freed all space of memory manager")
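
As a rough usage sketch of the manager above (again, not part of this diff): the sizes are arbitrary, CPU is used so the example runs without a GPU, and colossalai is assumed to be importable.

import torch

from colossalai.shardformer.inference import MemoryManager

manager = MemoryManager(size=16, dtype=torch.float16, head_num=2, head_dim=4,
                        layer_num=2, device=torch.device('cpu'))

# Prefill: grab one contiguous block covering the prompt tokens of the whole batch.
prefill_locs, start, end = manager.alloc_contiguous(8)
assert end - start == 8 and manager.available_size == 8

# Decode: one more slot per sequence; these slots need not be contiguous.
decode_locs = manager.alloc(4)
assert manager.available_size == 4

# Release the decode slots, then everything, once the batch is finished.
manager.free(decode_locs)
manager.free_all()
assert manager.available_size == 16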
60 changes: 60 additions & 0 deletions tests/test_infer/test_kvcache_manager.py
@@ -0,0 +1,60 @@
import os

import pytest
import torch

from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.inference import MemoryManager
from colossalai.testing import rerun_if_address_is_in_use, spawn

BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128


def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    disable_existing_loggers()

    size = batch_size * (input_len + output_len)
    kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
    key_buffers = kvcache_manager.key_buffer
    value_buffers = kvcache_manager.value_buffer
    assert len(key_buffers) == len(value_buffers) == layer_num
    assert key_buffers[0].shape == value_buffers[0].shape
    # required size exceeds the maximum allocated size
    invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
    assert invalid_locs is None
    # for the prefill stage, allocation via alloc and alloc_contiguous should be the same
    total_token_prefill = batch_size * input_len
    prefill_locs = kvcache_manager.alloc(total_token_prefill)
    kvcache_manager.free_all()
    prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
    assert torch.equal(prefill_locs, prefill_locs_contiguous)
    assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
    kvcache_manager.alloc_contiguous(batch_size)
    assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
    spawn(create_cache_manager,
          4,
          batch_size=BATCH_SIZE,
          input_len=INPUT_LEN,
          output_len=OUTPUT_LEN,
          layer_num=LAYER_NUM,
          head_num=HEAD_NUM,
          head_dim=HEAD_DIM)


if __name__ == '__main__':
    test_cache_manager_dist()
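
For quick local iteration without spawning 4 workers, a single-process variant of the same checks might look like the following sketch; the CPU device and the reuse of the constants defined above are assumptions for illustration, not part of this diff:

import torch

from colossalai.shardformer.inference import MemoryManager


def check_cache_manager_single_process():
    # reuses BATCH_SIZE, INPUT_LEN, OUTPUT_LEN, LAYER_NUM, HEAD_NUM, HEAD_DIM from the test module
    size = BATCH_SIZE * (INPUT_LEN + OUTPUT_LEN)
    manager = MemoryManager(size, torch.float16, HEAD_NUM, HEAD_DIM, LAYER_NUM, torch.device('cpu'))
    # oversized request is rejected
    assert manager.alloc_contiguous(size + 1) is None
    # prefill allocation on an empty manager matches the contiguous allocation after a reset
    total_token_prefill = BATCH_SIZE * INPUT_LEN
    prefill_locs = manager.alloc(total_token_prefill)
    manager.free_all()
    prefill_locs_contiguous = manager.alloc_contiguous(total_token_prefill)[0]
    assert torch.equal(prefill_locs, prefill_locs_contiguous)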