From fbff5d3967e626a0f54887ce9b6303996f5658e1 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 11:04:42 +0800 Subject: [PATCH 1/8] add kv cache memory manager --- colossalai/inference/__init__.py | 0 colossalai/inference/kvcache_manager.py | 87 +++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 colossalai/inference/__init__.py create mode 100644 colossalai/inference/kvcache_manager.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/inference/kvcache_manager.py new file mode 100644 index 000000000000..fce788e4535c --- /dev/null +++ b/colossalai/inference/kvcache_manager.py @@ -0,0 +1,87 @@ +# Adapted from lightllm/common/mem_manager.py +# of the ModelTC/lightllm GitHub repository +# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py +# +# Copyright 2023 ModelTC Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from colossalai.logging import get_dist_logger + + +class MemoryManager: + def __init__(self, size, dtype, head_num, head_dim, layer_num, device=None): + device = torch.cuda.current_device() if device is None else device + self.logger = get_dist_logger(__name__) + self.available_size = size + self.past_key_values_length = 0 + + self._init_mem_states(size, device) + self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) + + def _init_mem_states(self, size, dev): + self.mem_state = torch.ones((size,), dtype=torch.bool, device=dev) + self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=dev) + self.indexes = torch.arange(0, size, dtype=torch.long, device=dev) + + def _init_kv_buffers(self, size, dev, dtype, head_num, head_dim, layer_num): + self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] + self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] + + @torch.no_grad() + def alloc(self, required_size): + if required_size > self.available_size: + self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") + return None + + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) + select_index = torch.logical_and(self._mem_cum_sum <= required_size, self.mem_state == 1) + select_index = self.indexes[select_index] + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + return select_index + + @torch.no_grad() + def alloc_contiguous(self, required_size): + if required_size > self.available_size: + self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") + return None + + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) + sum_size = len(self._mem_cum_sum) + loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - 
required_size + 1] + self.mem_state[0:sum_size - required_size + 1] + can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] + if can_used_loc.shape[0] == 0: + self.logger.info(f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}") + return None + start_loc = can_used_loc[0] + select_index = self.indexes[start_loc : start_loc + required_size] + + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + start = start_loc.item() + end = start + required_size + return select_index, start, end + + @torch.no_grad() + def free(self, free_index): + self.available_size += free_index.shape[0] + self.mem_state[free_index] = 1 + + @torch.no_grad() + def free_all(self): + self.available_size = len(self.mem_state) + self.mem_state[:] = 1 + self.past_key_values_length = 0 + self.logger.info("freed all space of memory manager") From 2d55acefc3f8cd6916460ee9567b0aacf188f324 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 15:32:12 +0800 Subject: [PATCH 2/8] add stateinfo during inference --- colossalai/inference/inference_state.py | 51 +++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 colossalai/inference/inference_state.py diff --git a/colossalai/inference/inference_state.py b/colossalai/inference/inference_state.py new file mode 100644 index 000000000000..8941495805bd --- /dev/null +++ b/colossalai/inference/inference_state.py @@ -0,0 +1,51 @@ +# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later +from dataclasses import dataclass +from typing import Any + +import torch + +from colossalai.inference.kvcache_manager import MemoryManager + + +@dataclass +class InferenceState: + batch_size: int + max_len_in_batch: int + + cache_manager: MemoryManager = None + + block_loc: torch.Tensor = None + start_loc: torch.Tensor = None + seq_len: torch.Tensor = None + + is_context_stage: bool = False 
+ context_mem_index: torch.Tensor = None + decode_is_contiguous: bool = None + decode_mem_start: int = None + decode_mem_end: int = None + decode_mem_index: torch.Tensor = None + + device: torch.device = torch.device('cuda') + + @property + def total_token_num(self): + return self.batch_size * self.max_len_in_batch + + def set_cache_manager(self, manager: MemoryManager): + self.cache_manager = manager + + def step_inference_state(self): + self.start_loc = self.start_loc + torch.arange(0, self.batch_size, dtype=torch.int32, device=self.device) + self.seq_len += 1 + + @staticmethod + def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, + alloc_mem_index: torch.Tensor): + """ in-place update block loc mapping based on the sequence length of the inputs in current batch""" + start_index = 0 + seq_len_numpy = seq_len.cpu().numpy() + for i, cur_seq_len in enumerate(seq_len_numpy): + b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + + cur_seq_len] + start_index += cur_seq_len + return From cb45cf85cd0b900f37c60012ec3c1e8f3ee600ec Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 16:54:59 +0800 Subject: [PATCH 3/8] format --- colossalai/inference/kvcache_manager.py | 29 +++++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/inference/kvcache_manager.py index fce788e4535c..e460f0a980f9 100644 --- a/colossalai/inference/kvcache_manager.py +++ b/colossalai/inference/kvcache_manager.py @@ -17,19 +17,21 @@ # limitations under the License. 
import torch + from colossalai.logging import get_dist_logger class MemoryManager: + def __init__(self, size, dtype, head_num, head_dim, layer_num, device=None): device = torch.cuda.current_device() if device is None else device self.logger = get_dist_logger(__name__) self.available_size = size self.past_key_values_length = 0 - + self._init_mem_states(size, device) self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) - + def _init_mem_states(self, size, dev): self.mem_state = torch.ones((size,), dtype=torch.bool, device=dev) self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=dev) @@ -38,47 +40,50 @@ def _init_mem_states(self, size, dev): def _init_kv_buffers(self, size, dev, dtype, head_num, head_dim, layer_num): self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] - + @torch.no_grad() def alloc(self, required_size): if required_size > self.available_size: self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") return None - + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) select_index = torch.logical_and(self._mem_cum_sum <= required_size, self.mem_state == 1) select_index = self.indexes[select_index] self.mem_state[select_index] = 0 self.available_size -= len(select_index) return select_index - + @torch.no_grad() def alloc_contiguous(self, required_size): if required_size > self.available_size: self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") return None - + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) sum_size = len(self._mem_cum_sum) - loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - required_size + 1] + self.mem_state[0:sum_size - required_size + 1] + 
loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - required_size + + 1] + self.mem_state[0:sum_size - + required_size + 1] can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] if can_used_loc.shape[0] == 0: - self.logger.info(f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}") + self.logger.info( + f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}") return None start_loc = can_used_loc[0] - select_index = self.indexes[start_loc : start_loc + required_size] - + select_index = self.indexes[start_loc:start_loc + required_size] + self.mem_state[select_index] = 0 self.available_size -= len(select_index) start = start_loc.item() end = start + required_size return select_index, start, end - + @torch.no_grad() def free(self, free_index): self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 - + @torch.no_grad() def free_all(self): self.available_size = len(self.mem_state) From 4f21bc5f1bc1b64f00ca5276974cd0c0043c294e Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 18:13:14 +0800 Subject: [PATCH 4/8] format --- colossalai/inference/inference_state.py | 7 ++- colossalai/inference/kvcache_manager.py | 72 ++++++++++++++++--------- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/colossalai/inference/inference_state.py b/colossalai/inference/inference_state.py index 8941495805bd..306ce8d6a5e3 100644 --- a/colossalai/inference/inference_state.py +++ b/colossalai/inference/inference_state.py @@ -8,7 +8,11 @@ @dataclass -class InferenceState: +class BatchInferState: + r""" + Information to be passed and used for a batch of inputs during + a single model forward + """ batch_size: int max_len_in_batch: int @@ -35,6 +39,7 @@ def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager def step_inference_state(self): + """ update indexes used for kv cache 
management at the end of model forward """ self.start_loc = self.start_loc + torch.arange(0, self.batch_size, dtype=torch.int32, device=self.device) self.seq_len += 1 diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/inference/kvcache_manager.py index e460f0a980f9..8f8c40a20890 100644 --- a/colossalai/inference/kvcache_manager.py +++ b/colossalai/inference/kvcache_manager.py @@ -22,33 +22,55 @@ class MemoryManager: + r""" + Manage token block indexes and allocate physical memory for key and value cache - def __init__(self, size, dtype, head_num, head_dim, layer_num, device=None): - device = torch.cuda.current_device() if device is None else device + Args: + size: maximum token number used as the size of key and value buffer + dtype: data type of cached key and value + head_num: number of heads the memory manager is responsible for + head_dim: embedded size per head + layer_num: the number of layers in the model + device: device used to store the key and value cache + """ + + def __init__(self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: torch.device = torch.device('cuda')): self.logger = get_dist_logger(__name__) self.available_size = size self.past_key_values_length = 0 - self._init_mem_states(size, device) self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) - def _init_mem_states(self, size, dev): - self.mem_state = torch.ones((size,), dtype=torch.bool, device=dev) - self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=dev) - self.indexes = torch.arange(0, size, dtype=torch.long, device=dev) + def _init_mem_states(self, size, device): + """ Initialize tensors used to manage memory states """ + self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) + self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) + self.indexes = torch.arange(0, size, dtype=torch.long, device=device) - def _init_kv_buffers(self, size, dev, dtype, 
head_num, head_dim, layer_num): - self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] - self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] + def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): + """ Initialize key buffer and value buffer on specified device """ + self.key_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + self.value_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] @torch.no_grad() def alloc(self, required_size): + """ allocate space of required_size by providing indexes representing available physical spaces """ if required_size > self.available_size: - self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " + f"left_size {self.available_size}") return None - - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) - select_index = torch.logical_and(self._mem_cum_sum <= required_size, self.mem_state == 1) + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) select_index = self.indexes[select_index] self.mem_state[select_index] = 0 self.available_size -= len(select_index) @@ -56,23 +78,23 @@ def alloc(self, required_size): @torch.no_grad() def alloc_contiguous(self, required_size): + """ allocate contiguous space of required_size """ if required_size > self.available_size: - self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " + f"left_size {self.available_size}") return None - - 
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) - sum_size = len(self._mem_cum_sum) - loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - required_size + - 1] + self.mem_state[0:sum_size - - required_size + 1] + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + sum_size = len(self.mem_cum_sum) + loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + + 1] + self.mem_state[0:sum_size - + required_size + 1] can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] if can_used_loc.shape[0] == 0: - self.logger.info( - f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}") + self.logger.info(f"No enough contiguous cache: required_size {required_size} " + f"left_size {self.available_size}") return None start_loc = can_used_loc[0] select_index = self.indexes[start_loc:start_loc + required_size] - self.mem_state[select_index] = 0 self.available_size -= len(select_index) start = start_loc.item() @@ -81,11 +103,13 @@ def alloc_contiguous(self, required_size): @torch.no_grad() def free(self, free_index): + """ free memory by updating memory states based on given indexes """ self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 @torch.no_grad() def free_all(self): + """ free all memory by updating memory states """ self.available_size = len(self.mem_state) self.mem_state[:] = 1 self.past_key_values_length = 0 From 389d0d4028a68b7b2fdeed03cf972090f248da13 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 18:14:22 +0800 Subject: [PATCH 5/8] rename file --- colossalai/inference/{inference_state.py => batch_infer_state.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename colossalai/inference/{inference_state.py => batch_infer_state.py} (100%) diff --git a/colossalai/inference/inference_state.py b/colossalai/inference/batch_infer_state.py 
similarity index 100% rename from colossalai/inference/inference_state.py rename to colossalai/inference/batch_infer_state.py From bdba1b560f0f90253a96dcdcfa70b829c548dc2b Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 24 Aug 2023 10:28:58 +0800 Subject: [PATCH 6/8] add kv cache test --- colossalai/inference/__init__.py | 4 ++ tests/test_infer/test_kvcache_manager.py | 60 ++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 tests/test_infer/test_kvcache_manager.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index e69de29bb2d1..1bce92653a8e 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -0,0 +1,4 @@ +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager + +__all__ = ['BatchInferState', 'MemoryManager'] diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py new file mode 100644 index 000000000000..2a34bb0a8c48 --- /dev/null +++ b/tests/test_infer/test_kvcache_manager.py @@ -0,0 +1,60 @@ +import os + +import pytest +import torch + +from colossalai.inference import MemoryManager +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +BATCH_SIZE = 4 +INPUT_LEN = 16 +OUTPUT_LEN = 8 +LAYER_NUM = 4 +HEAD_NUM = 32 +HEAD_DIM = 128 + + +def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + disable_existing_loggers() + + size = batch_size * (input_len + output_len) + kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank) + key_buffers = kvcache_manager.key_buffer + value_buffers = kvcache_manager.value_buffer + 
assert len(key_buffers) == len(value_buffers) == layer_num + assert key_buffers[0].shape == value_buffers[0].shape + # required size exceeds the maximum allocated size + invalid_locs = kvcache_manager.alloc_contiguous(size + 1) + assert invalid_locs is None + # for prefill stage, allocation via alloc and alloc_contiguous should be the same + total_token_prefill = batch_size * input_len + prefill_locs = kvcache_manager.alloc(total_token_prefill) + kvcache_manager.free_all() + prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0] + assert torch.equal(prefill_locs, prefill_locs_contiguous) + assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill + kvcache_manager.alloc_contiguous(batch_size) + assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_cache_manager_dist(): + spawn(create_cache_manager, + 4, + batch_size=BATCH_SIZE, + input_len=INPUT_LEN, + output_len=OUTPUT_LEN, + layer_num=LAYER_NUM, + head_num=HEAD_NUM, + head_dim=HEAD_DIM) + + +if __name__ == '__main__': + test_cache_manager_dist() From 813e23a7b3e20fea9e70ff41b7bbfc8120f13305 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 24 Aug 2023 11:34:56 +0800 Subject: [PATCH 7/8] revise on BatchInferState --- colossalai/inference/batch_infer_state.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/inference/batch_infer_state.py b/colossalai/inference/batch_infer_state.py index 306ce8d6a5e3..f06b0aadbf55 100644 --- a/colossalai/inference/batch_infer_state.py +++ b/colossalai/inference/batch_infer_state.py @@ -28,6 +28,7 @@ class BatchInferState: decode_mem_start: int = None decode_mem_end: int = None decode_mem_index: torch.Tensor = None + decode_layer_id: int = None device: torch.device = torch.device('cuda') @@ -38,11 +39,6 @@ def total_token_num(self): def set_cache_manager(self, manager: MemoryManager): 
self.cache_manager = manager - def step_inference_state(self): - """ update indexes used for kv cache management at the end of model forward """ - self.start_loc = self.start_loc + torch.arange(0, self.batch_size, dtype=torch.int32, device=self.device) - self.seq_len += 1 - @staticmethod def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor): From 05098b97b3e54113f043b15b3f0a271c6d415d7f Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 24 Aug 2023 16:26:44 +0800 Subject: [PATCH 8/8] file dir change --- colossalai/{ => shardformer}/inference/__init__.py | 0 colossalai/{ => shardformer}/inference/batch_infer_state.py | 2 +- colossalai/{ => shardformer}/inference/kvcache_manager.py | 0 tests/test_infer/test_kvcache_manager.py | 2 +- 4 files changed, 2 insertions(+), 2 deletions(-) rename colossalai/{ => shardformer}/inference/__init__.py (100%) rename colossalai/{ => shardformer}/inference/batch_infer_state.py (96%) rename colossalai/{ => shardformer}/inference/kvcache_manager.py (100%) diff --git a/colossalai/inference/__init__.py b/colossalai/shardformer/inference/__init__.py similarity index 100% rename from colossalai/inference/__init__.py rename to colossalai/shardformer/inference/__init__.py diff --git a/colossalai/inference/batch_infer_state.py b/colossalai/shardformer/inference/batch_infer_state.py similarity index 96% rename from colossalai/inference/batch_infer_state.py rename to colossalai/shardformer/inference/batch_infer_state.py index f06b0aadbf55..fef23a584b8b 100644 --- a/colossalai/inference/batch_infer_state.py +++ b/colossalai/shardformer/inference/batch_infer_state.py @@ -4,7 +4,7 @@ import torch -from colossalai.inference.kvcache_manager import MemoryManager +from .kvcache_manager import MemoryManager @dataclass diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/shardformer/inference/kvcache_manager.py similarity index 100% rename from 
colossalai/inference/kvcache_manager.py rename to colossalai/shardformer/inference/kvcache_manager.py diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index 2a34bb0a8c48..ef48444f73ca 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -3,8 +3,8 @@ import pytest import torch -from colossalai.inference import MemoryManager from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.inference import MemoryManager from colossalai.testing import rerun_if_address_is_in_use, spawn BATCH_SIZE = 4