From fbff5d3967e626a0f54887ce9b6303996f5658e1 Mon Sep 17 00:00:00 2001
From: yuanheng-zhao
Date: Wed, 23 Aug 2023 11:04:42 +0800
Subject: [PATCH 01/28] add kv cache memory manager

---
 colossalai/inference/__init__.py        |  0
 colossalai/inference/kvcache_manager.py | 87 +++++++++++++++++++++++++
 2 files changed, 87 insertions(+)
 create mode 100644 colossalai/inference/__init__.py
 create mode 100644 colossalai/inference/kvcache_manager.py

diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/inference/kvcache_manager.py
new file mode 100644
index 000000000000..fce788e4535c
--- /dev/null
+++ b/colossalai/inference/kvcache_manager.py
@@ -0,0 +1,87 @@
+# Adapted from lightllm/common/mem_manager.py
+# of the ModelTC/lightllm GitHub repository
+# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+#
+# Copyright 2023 ModelTC Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from colossalai.logging import get_dist_logger
+
+
+class MemoryManager:
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, device=None):
+        device = torch.cuda.current_device() if device is None else device
+        self.logger = get_dist_logger(__name__)
+        self.available_size = size
+        self.past_key_values_length = 0
+
+        self._init_mem_states(size, device)
+        self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
+
+    def _init_mem_states(self, size, dev):
+        self.mem_state = torch.ones((size,), dtype=torch.bool, device=dev)
+        self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=dev)
+        self.indexes = torch.arange(0, size, dtype=torch.long, device=dev)
+
+    def _init_kv_buffers(self, size, dev, dtype, head_num, head_dim, layer_num):
+        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)]
+        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)]
+
+    @torch.no_grad()
+    def alloc(self, required_size):
+        if required_size > self.available_size:
+            self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}")
+            return None
+
+        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum)
+        select_index = torch.logical_and(self._mem_cum_sum <= required_size, self.mem_state == 1)
+        select_index = self.indexes[select_index]
+        self.mem_state[select_index] = 0
+        self.available_size -= len(select_index)
+        return select_index
+
+    @torch.no_grad()
+    def alloc_contiguous(self, required_size):
+        if required_size > self.available_size:
+            self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}")
+            return None
+
+        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum)
+        sum_size = len(self._mem_cum_sum)
+        loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - required_size + 1] + self.mem_state[0:sum_size - required_size + 1]
+        can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
+        if can_used_loc.shape[0] == 0:
+            self.logger.info(f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}")
+            return None
+        start_loc = can_used_loc[0]
+        select_index = self.indexes[start_loc : start_loc + required_size]
+
+        self.mem_state[select_index] = 0
+        self.available_size -= len(select_index)
+        start = start_loc.item()
+        end = start + required_size
+        return select_index, start, end
+
+    @torch.no_grad()
+    def free(self, free_index):
+        self.available_size += free_index.shape[0]
+        self.mem_state[free_index] = 1
+
+    @torch.no_grad()
+    def free_all(self):
+        self.available_size = len(self.mem_state)
+        self.mem_state[:] = 1
+        self.past_key_values_length = 0
+        self.logger.info("freed all space of memory manager")

From 2d55acefc3f8cd6916460ee9567b0aacf188f324 Mon Sep 17 00:00:00 2001
From: yuanheng-zhao
Date: Wed, 23 Aug 2023 15:32:12 +0800
Subject: [PATCH 02/28] add stateinfo during inference

---
 colossalai/inference/inference_state.py | 51 +++++++++++++++++++++++++
 1 file changed, 51 insertions(+)
 create mode 100644 colossalai/inference/inference_state.py

diff --git a/colossalai/inference/inference_state.py b/colossalai/inference/inference_state.py
new file mode 100644
index 000000000000..8941495805bd
--- /dev/null
+++ b/colossalai/inference/inference_state.py
@@ -0,0 +1,51 @@
+# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from colossalai.inference.kvcache_manager import MemoryManager
+
+
+@dataclass
+class InferenceState:
+    batch_size: int
+    max_len_in_batch: int
+
+    cache_manager: MemoryManager = None
+
+    block_loc: torch.Tensor = None
+    start_loc: torch.Tensor = None
+    seq_len: torch.Tensor = None
+
+    is_context_stage: bool = False
+    context_mem_index: torch.Tensor = None
+    decode_is_contiguous: bool = None
+    decode_mem_start: int = None
+    decode_mem_end: int = None
+    decode_mem_index: torch.Tensor = None
+
+    device: torch.device = torch.device('cuda')
+
+    @property
+    def total_token_num(self):
+        return self.batch_size * self.max_len_in_batch
+
+    def set_cache_manager(self, manager: MemoryManager):
+        self.cache_manager = manager
+
+    def step_inference_state(self):
+        self.start_loc = self.start_loc + torch.arange(0, self.batch_size, dtype=torch.int32, device=self.device)
+        self.seq_len += 1
+
+    @staticmethod
+    def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
+                       alloc_mem_index: torch.Tensor):
+        """ in-place update block loc mapping based on the sequence length of the inputs in current bath"""
+        start_index = 0
+        seq_len_numpy = seq_len.cpu().numpy()
+        for i, cur_seq_len in enumerate(seq_len_numpy):
+            b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
+                                                                                        cur_seq_len]
+            start_index += cur_seq_len
+        return

From e55e565b74db5ee2143ee6065d5f71cf6c27d9f1 Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Tue, 22 Aug 2023 19:07:22 +0800
Subject: [PATCH 03/28] add

---
 colossalai/shardformer/policies/llama.py | 4 ++++
 tests/kit/model_zoo/torchrec/__init__.py | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5ee95f3be8fa..2f36eb1a814b 100644
--- 
a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -263,3 +263,7 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in llama for sequence classification model""" return [] + + +class LlamaInferPolicy(LlamaPolicy): + pass diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * From a9715358974b9c07b5858db26698efcd24b45321 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 23 Aug 2023 15:47:15 +0800 Subject: [PATCH 04/28] add infer example --- colossalai/shardformer/modeling/llama.py | 81 ++++++++++++++++++++ colossalai/shardformer/policies/llama.py | 25 +++++- colossalai/shardformer/shard/shard_config.py | 7 ++ tests/test_infer/_utils.py | 58 ++++++++++++++ tests/test_infer/test_llama_infer.py | 55 +++++++++++++ 5 files changed, 223 insertions(+), 3 deletions(-) create mode 100644 tests/test_infer/_utils.py create mode 100644 tests/test_infer/test_llama_infer.py diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index f1d2998bbee4..e7b4be701849 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -19,6 +19,7 @@ class LlamaPipelineForwards: under pipeline setting. ''' + @staticmethod def llama_model_forward( self: LlamaModel, input_ids: torch.LongTensor = None, @@ -169,6 +170,7 @@ def custom_forward(*inputs): # always return dict for imediate stage return {'hidden_states': hidden_states} + @staticmethod def llama_for_causal_lm_forward( self: LlamaForCausalLM, input_ids: torch.LongTensor = None, @@ -276,6 +278,7 @@ def llama_for_causal_lm_forward( hidden_states = outputs.get('hidden_states') return {'hidden_states': hidden_states} + @staticmethod def llama_for_sequence_classification_forward( self: LlamaForSequenceClassification, input_ids: torch.LongTensor = None, @@ -388,6 +391,84 @@ def llama_for_sequence_classification_forward( return {'hidden_states': hidden_states} +class LlamaInferenceForwards: + """ + This class holds forwards for llama infer + """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[ + torch.LongTensor] = None, # TODO: this can also be removed if we got sin,cos cached in inferinfo + past_key_values: Optional[List[ + torch.FloatTensor]] = None, #TODO: maybe removed after memory cache manager is done. 
+ inputs_embeds: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + inferinfo=None, + ): + # only keep the basic items + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, + past_key_values_length) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] if past_key_values is not None else None + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.norm(hidden_states) + + if not return_dict: + return hidden_states + return BaseModelOutputWithPast(last_hidden_state=hidden_states,) + + def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2f36eb1a814b..731695c0d8e3 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -7,7 +7,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward +from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] @@ -265,5 +265,24 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] -class LlamaInferPolicy(LlamaPolicy): - pass +class LlamaModelInferPolicy(LlamaPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + policy = super().module_policy() + self.shard_config._infer() + + # example for replace layer or decoder + # if 
self.shard_config.enable_flash_attention: + # policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + # 'forward': get_llama_flash_attention_forward(), + # }) + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + return policy diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0c28f115d018..35b526456b10 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -28,6 +28,7 @@ class ShardConfig: enable_all_optimization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False + inference_only: bool = False # pipeline_parallel_size: int # data_parallel_size: int @@ -57,3 +58,9 @@ def _turn_on_all_optimization(self): self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True + + def _infer(self): + """ + Set default params for inference. + """ + self.pipeline_stage_manager = None diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py new file mode 100644 index 000000000000..8e3c9ff64187 --- /dev/null +++ b/tests/test_infer/_utils.py @@ -0,0 +1,58 @@ +import copy + +import torch +import torch.distributed as dist +from torch import Tensor +from torch import distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module +from torch.optim import Adam, Optimizer + +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer._utils import getattr_ +from colossalai.shardformer.policies.auto_policy import Policy +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor + + +def build_model( + model_fn, + enable_fused_normalization=False, + enable_tensor_parallelism=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, +): + # create new model + org_model = model_fn() + + # shard model + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + inference_only=True) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, shared_params = shard_former.optimize(model_copy) + return org_model.cuda(), sharded_model.cuda() + + +def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn): + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + # switch to train mode + original_model.train() + sharded_model.train() + # run forward + org_output = original_model(**data) + org_output = output_transform_fn(org_output) + + shard_output = sharded_model(**data) + shard_output = output_transform_fn(shard_output) + + return org_output, shard_output diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py new file mode 100644 index 000000000000..09a81ef7f0a8 --- /dev/null +++ b/tests/test_infer/test_llama_infer.py @@ -0,0 +1,55 @@ +import os + 
+import pytest +import torch +from torch import distributed as dist + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_infer._utils import build_model, run_infer + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config): + org_model, sharded_model = build_model(model_fn, **test_config) + + org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn) + + print('original output', org_output[0]) + print('infer output', infer_output[0]) + + +@parameterize('test_config', [{ + 'enable_flash_attention': False, +}]) +def run_llama_test(test_config): + + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != "transformers_llama": + continue + check_infer(model_fn, data_gen_fn, output_transform_fn, test_config) + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 1) + + +if __name__ == "__main__": + test_llama() From 0ae5bb7778c4e6abac4efaf9023a7d9006b73337 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 23 Aug 2023 16:19:20 +0800 Subject: [PATCH 05/28] finish --- colossalai/shardformer/policies/auto_policy.py | 14 ++++++++++++-- colossalai/shardformer/shard/sharder.py | 3 ++- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/test_infer/_utils.py | 2 -- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index eec339c02872..43ea1c5ab7f6 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -1,5 +1,6 @@ import importlib from dataclasses import dataclass +from typing import Optional import torch.nn as nn @@ -130,6 +131,12 @@ class PolicyLocation: PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), } +_INFER_POLICY_LIST = { + # LlaMa + "transformers.models.llama.modeling_llama.LlamaModel": + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy") +} + def import_policy(policy_location: PolicyLocation) -> Policy: """ @@ -151,7 +158,7 @@ def _fullname(obj): return module + '.' 
+ klass.__qualname__ -def get_autopolicy(model: nn.Module) -> Policy: +def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy: r""" Return the auto policy for the model @@ -162,7 +169,10 @@ def get_autopolicy(model: nn.Module) -> Policy: :class:`Policy`: The auto policy for the model """ full_name = _fullname(model) - policy_location = _POLICY_LIST.get(full_name, None) + if inference_only: + policy_location = _INFER_POLICY_LIST.get(full_name, None) + else: + policy_location = _POLICY_LIST.get(full_name, None) if policy_location is None: raise NotImplementedError( diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 0ed745a1fc4a..39704ae5e3ec 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -27,7 +27,8 @@ class ModelSharder(object): def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None: self.model = model - self.policy = get_autopolicy(self.model) if policy is None else policy + self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy + print(self.policy) self.shard_config = shard_config def shard(self) -> List[Dict[int, Tensor]]: diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py index 8e3c9ff64187..68eb605ef5e6 100644 --- a/tests/test_infer/_utils.py +++ b/tests/test_infer/_utils.py @@ -23,7 +23,6 @@ def build_model( enable_tensor_parallelism=False, enable_flash_attention=False, enable_jit_fused=False, - enable_sequence_parallelism=False, ): # create new model org_model = model_fn() @@ -33,7 +32,6 @@ def build_model( enable_tensor_parallelism=enable_tensor_parallelism, enable_flash_attention=enable_flash_attention, enable_jit_fused=enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism, inference_only=True) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) From a8f73864f4f1b082e819c1be4b6246763663533b Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 23 Aug 2023 16:35:43 +0800 Subject: [PATCH 06/28] finish --- tests/test_infer/_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py index 68eb605ef5e6..3d56cc3484a6 100644 --- a/tests/test_infer/_utils.py +++ b/tests/test_infer/_utils.py @@ -43,9 +43,6 @@ def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn): # prepare input data = data_gen_fn() data = {k: v.cuda() for k, v in data.items()} - # switch to train mode - original_model.train() - sharded_model.train() # run forward org_output = original_model(**data) org_output = output_transform_fn(org_output) From cb45cf85cd0b900f37c60012ec3c1e8f3ee600ec Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 16:54:59 +0800 Subject: [PATCH 07/28] format --- colossalai/inference/kvcache_manager.py | 29 +++++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/inference/kvcache_manager.py index fce788e4535c..e460f0a980f9 100644 --- a/colossalai/inference/kvcache_manager.py +++ b/colossalai/inference/kvcache_manager.py @@ -17,19 +17,21 
@@ # limitations under the License. import torch + from colossalai.logging import get_dist_logger class MemoryManager: + def __init__(self, size, dtype, head_num, head_dim, layer_num, device=None): device = torch.cuda.current_device() if device is None else device self.logger = get_dist_logger(__name__) self.available_size = size self.past_key_values_length = 0 - + self._init_mem_states(size, device) self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) - + def _init_mem_states(self, size, dev): self.mem_state = torch.ones((size,), dtype=torch.bool, device=dev) self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=dev) @@ -38,47 +40,50 @@ def _init_mem_states(self, size, dev): def _init_kv_buffers(self, size, dev, dtype, head_num, head_dim, layer_num): self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] - + @torch.no_grad() def alloc(self, required_size): if required_size > self.available_size: self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") return None - + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) select_index = torch.logical_and(self._mem_cum_sum <= required_size, self.mem_state == 1) select_index = self.indexes[select_index] self.mem_state[select_index] = 0 self.available_size -= len(select_index) return select_index - + @torch.no_grad() def alloc_contiguous(self, required_size): if required_size > self.available_size: self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") return None - + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) sum_size = len(self._mem_cum_sum) - loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - required_size + 1] + self.mem_state[0:sum_size - required_size + 1] + loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - required_size + + 1] + self.mem_state[0:sum_size - + required_size + 1] can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] if can_used_loc.shape[0] == 0: - self.logger.info(f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}") + self.logger.info( + f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}") return None start_loc = can_used_loc[0] - select_index = self.indexes[start_loc : start_loc + required_size] - + select_index = self.indexes[start_loc:start_loc + required_size] + self.mem_state[select_index] = 0 self.available_size -= len(select_index) start = start_loc.item() end = start + required_size return select_index, start, end - + @torch.no_grad() def free(self, free_index): self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 - + @torch.no_grad() def free_all(self): self.available_size = len(self.mem_state) From 4f21bc5f1bc1b64f00ca5276974cd0c0043c294e Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 18:13:14 +0800 Subject: [PATCH 08/28] format --- colossalai/inference/inference_state.py | 7 ++- colossalai/inference/kvcache_manager.py | 72 ++++++++++++++++--------- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/colossalai/inference/inference_state.py b/colossalai/inference/inference_state.py index 8941495805bd..306ce8d6a5e3 100644 --- 
a/colossalai/inference/inference_state.py +++ b/colossalai/inference/inference_state.py @@ -8,7 +8,11 @@ @dataclass -class InferenceState: +class BatchInferState: + r""" + Information to be passed and used for a batch of inputs during + a single model forward + """ batch_size: int max_len_in_batch: int @@ -35,6 +39,7 @@ def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager def step_inference_state(self): + """ update indexes used for kv cache management at the end of model forward """ self.start_loc = self.start_loc + torch.arange(0, self.batch_size, dtype=torch.int32, device=self.device) self.seq_len += 1 diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/inference/kvcache_manager.py index e460f0a980f9..8f8c40a20890 100644 --- a/colossalai/inference/kvcache_manager.py +++ b/colossalai/inference/kvcache_manager.py @@ -22,33 +22,55 @@ class MemoryManager: + r""" + Manage token block indexes and allocate physical memory for key and value cache - def __init__(self, size, dtype, head_num, head_dim, layer_num, device=None): - device = torch.cuda.current_device() if device is None else device + Args: + size: maximum token number used as the size of key and value buffer + dtype: data type of cached key and value + head_num: number of heads the memory manager is responsible for + head_dim: embedded size per head + layer_num: the number of layers in the model + device: device used to store the key and value cache + """ + + def __init__(self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: torch.device = torch.device('cuda')): self.logger = get_dist_logger(__name__) self.available_size = size self.past_key_values_length = 0 - self._init_mem_states(size, device) self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) - def _init_mem_states(self, size, dev): - self.mem_state = torch.ones((size,), dtype=torch.bool, device=dev) - self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=dev) - self.indexes = torch.arange(0, size, dtype=torch.long, device=dev) + def _init_mem_states(self, size, device): + """ Initialize tensors used to manage memory states """ + self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) + self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) + self.indexes = torch.arange(0, size, dtype=torch.long, device=device) - def _init_kv_buffers(self, size, dev, dtype, head_num, head_dim, layer_num): - self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] - self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device=dev) for _ in range(layer_num)] + def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): + """ Initialize key buffer and value buffer on specified device """ + self.key_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + self.value_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] @torch.no_grad() def alloc(self, required_size): + """ allocate space of required_size by providing indexes representing available physical spaces """ if required_size > self.available_size: - self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " + f"left_size {self.available_size}") return None - - 
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) - select_index = torch.logical_and(self._mem_cum_sum <= required_size, self.mem_state == 1) + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) select_index = self.indexes[select_index] self.mem_state[select_index] = 0 self.available_size -= len(select_index) @@ -56,23 +78,23 @@ def alloc(self, required_size): @torch.no_grad() def alloc_contiguous(self, required_size): + """ allocate contiguous space of required_size """ if required_size > self.available_size: - self.logger.warning(f"warn no enough cache required_size {required_size} left_size {self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " + f"left_size {self.available_size}") return None - - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum) - sum_size = len(self._mem_cum_sum) - loc_sums = self._mem_cum_sum[required_size - 1:] - self._mem_cum_sum[0:sum_size - required_size + - 1] + self.mem_state[0:sum_size - - required_size + 1] + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + sum_size = len(self.mem_cum_sum) + loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + + 1] + self.mem_state[0:sum_size - + required_size + 1] can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] if can_used_loc.shape[0] == 0: - self.logger.info( - f"Not enough contiguous cache: required_size {required_size} left_size {self.available_size}") + self.logger.info(f"No enough contiguous cache: required_size {required_size} " + f"left_size {self.available_size}") return None start_loc = can_used_loc[0] select_index = self.indexes[start_loc:start_loc + required_size] - self.mem_state[select_index] = 0 self.available_size -= len(select_index) start = start_loc.item() @@ -81,11 +103,13 @@ def alloc_contiguous(self, required_size): @torch.no_grad() def free(self, free_index): + """ free memory by updating memory states based on given indexes """ self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 @torch.no_grad() def free_all(self): + """ free all memory by updating memory states """ self.available_size = len(self.mem_state) self.mem_state[:] = 1 self.past_key_values_length = 0 From 389d0d4028a68b7b2fdeed03cf972090f248da13 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 23 Aug 2023 18:14:22 +0800 Subject: [PATCH 09/28] rename file --- colossalai/inference/{inference_state.py => batch_infer_state.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename colossalai/inference/{inference_state.py => batch_infer_state.py} (100%) diff --git a/colossalai/inference/inference_state.py b/colossalai/inference/batch_infer_state.py similarity index 100% rename from colossalai/inference/inference_state.py rename to colossalai/inference/batch_infer_state.py From bdba1b560f0f90253a96dcdcfa70b829c548dc2b Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 24 Aug 2023 10:28:58 +0800 Subject: [PATCH 10/28] add kv cache test --- colossalai/inference/__init__.py | 4 ++ tests/test_infer/test_kvcache_manager.py | 60 ++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 tests/test_infer/test_kvcache_manager.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index e69de29bb2d1..1bce92653a8e 100644 --- 
a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -0,0 +1,4 @@ +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager + +__all__ = ['BatchInferState', 'MemoryManager'] diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py new file mode 100644 index 000000000000..2a34bb0a8c48 --- /dev/null +++ b/tests/test_infer/test_kvcache_manager.py @@ -0,0 +1,60 @@ +import os + +import pytest +import torch + +from colossalai.inference import MemoryManager +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn + +BATCH_SIZE = 4 +INPUT_LEN = 16 +OUTPUT_LEN = 8 +LAYER_NUM = 4 +HEAD_NUM = 32 +HEAD_DIM = 128 + + +def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + disable_existing_loggers() + + size = batch_size * (input_len + output_len) + kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank) + key_buffers = kvcache_manager.key_buffer + value_buffers = kvcache_manager.value_buffer + assert len(key_buffers) == len(value_buffers) == layer_num + assert key_buffers[0].shape == value_buffers[0].shape + # required size exceeds the maximum allocated size + invalid_locs = kvcache_manager.alloc_contiguous(size + 1) + assert invalid_locs is None + # for prefill stage, allocation via alloc and alloc_contiguous should be the same + total_token_prefill = batch_size * input_len + prefill_locs = kvcache_manager.alloc(total_token_prefill) + kvcache_manager.free_all() + prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0] + assert torch.equal(prefill_locs, prefill_locs_contiguous) + assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill + kvcache_manager.alloc_contiguous(batch_size) + assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_cache_manager_dist(): + spawn(create_cache_manager, + 4, + batch_size=BATCH_SIZE, + input_len=INPUT_LEN, + output_len=OUTPUT_LEN, + layer_num=LAYER_NUM, + head_num=HEAD_NUM, + head_dim=HEAD_DIM) + + +if __name__ == '__main__': + test_cache_manager_dist() From 813e23a7b3e20fea9e70ff41b7bbfc8120f13305 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 24 Aug 2023 11:34:56 +0800 Subject: [PATCH 11/28] revise on BatchInferState --- colossalai/inference/batch_infer_state.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/inference/batch_infer_state.py b/colossalai/inference/batch_infer_state.py index 306ce8d6a5e3..f06b0aadbf55 100644 --- a/colossalai/inference/batch_infer_state.py +++ b/colossalai/inference/batch_infer_state.py @@ -28,6 +28,7 @@ class BatchInferState: decode_mem_start: int = None decode_mem_end: int = None decode_mem_index: torch.Tensor = None + decode_layer_id: int = None device: torch.device = torch.device('cuda') @@ -38,11 +39,6 @@ def total_token_num(self): def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager - def step_inference_state(self): - """ update indexes used for kv cache management at the end of model forward """ - self.start_loc = self.start_loc + torch.arange(0, 
self.batch_size, dtype=torch.int32, device=self.device) - self.seq_len += 1 - @staticmethod def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor): From 469a3c5f27143ea0eb09b22cef4c3940e51a6d3d Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 24 Aug 2023 16:15:25 +0800 Subject: [PATCH 12/28] add inference test for llama --- colossalai/shardformer/modeling/llama.py | 293 +++++++++++++++++++++-- colossalai/shardformer/policies/llama.py | 8 + tests/test_infer/llama_infer_eigine.py | 236 ++++++++++++++++++ tests/test_infer/test_llama_infer.py | 38 ++- 4 files changed, 532 insertions(+), 43 deletions(-) create mode 100644 tests/test_infer/llama_infer_eigine.py diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index e7b4be701849..4c23cc50feb7 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,14 +1,16 @@ from typing import Callable, List, Optional, Tuple import torch +import numpy as np from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaAttention from transformers.utils import logging +from colossalai.inference.batch_infer_state import BatchInferState from colossalai.pipeline.stage_manager import PipelineStageManager @@ -401,15 +403,32 @@ def llama_model_forward( self: LlamaModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[ - torch.LongTensor] = None, # TODO: this can also be removed if we got sin,cos cached in inferinfo - past_key_values: Optional[List[ - torch.FloatTensor]] = None, #TODO: maybe removed after memory cache manager is done. 
+ position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - inferinfo=None, ): - # only keep the basic items + batch_size = input_ids.shape[0] # input_ids.shape[0] + + infer_info = BatchInferState(batch_size, input_ids.shape[1]) + infer_info.batch_size = batch_size + # NOTE: dummy implementation here for testing, just assume all inputs same length + infer_info.block_loc = self.block_loc + infer_info.start_loc = self.start_loc + infer_info.seq_len = self.seq_len + infer_info.max_len_in_batch = self.max_len_in_batch + + b_seq_len_numpy = infer_info.seq_len.cpu().numpy() + position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) + for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + + # this equals + infer_info.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) + infer_info.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds @@ -426,49 +445,283 @@ def llama_model_forward( past_key_values_length = 0 if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + # TODO dummy but work, revise it + past_key_values_length = self.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length + infer_info.set_cache_manager(self.cache_manager) + + # FIXME: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assuem prefill stage + # allocate memory block + infer_info.is_context_stage = True # set prefill stage, notify attention layer + infer_info.context_mem_index = infer_info.cache_manager.alloc(infer_info.total_token_num) + infer_info.init_block_loc(infer_info.block_loc, infer_info.seq_len, seq_length, infer_info.context_mem_index) + else: + # TODO handle the condition that no contiguous memory presents + alloc_mem = self.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_info.decode_mem_index = alloc_mem[0] + infer_info.decode_mem_start = alloc_mem[1] + infer_info.decode_mem_end = alloc_mem[2] + infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print(f" infer_info.cache_manager.past_key_values_length: {infer_info.cache_manager.past_key_values_length}") + infer_info.decode_is_contiguous = False + alloc_mem = self.cache_manager.alloc(batch_size) + infer_info.decode_mem_index = alloc_mem + # infer_info.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_info.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - 
dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # embed positions if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + infer_info.decode_layer_id = 0 + for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] if past_key_values is not None else None + # NOTE: modify here for passing args to decoder layer layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_info=infer_info, ) - + infer_info.decode_layer_id += 1 hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) hidden_states = self.norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + + # update indices + self.max_len_in_batch += 1 + self.block_loc[:, self.max_len_in_batch-1] = self.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + self.start_loc = self.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + self.total_token_num += batch_size + self.seq_len += 1 if not return_dict: - return hidden_states - return BaseModelOutputWithPast(last_hidden_state=hidden_states,) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + @staticmethod + def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_info: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_info=infer_info, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = 
self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_info: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # TODO might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states_transposed = key_states.transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len) + cos ,sin = infer_info.position_cos, infer_info.position_sin + + cos_sin_cache = torch.cat((cos, sin), dim=-1) + + from col_pos_encoding_ops import rotary_embedding_neox + + rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) + + from inference.ops.triton.k_copy_kv import destindex_copy_kv + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + num_heads = key_buffer.shape[2] + head_dim = key_buffer.shape[3] + key_buffer = key_buffer.view(-1, num_heads, head_dim) + value_buffer = value_buffer.view(-1, num_heads, head_dim) + destindex_copy_kv(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + destindex_copy_kv(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + # copy key and value calculated in current step to memory manager + if infer_info.is_context_stage: + _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.context_mem_index, infer_info.cache_manager) + else: + _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.decode_mem_index, infer_info.cache_manager) + + # this is worse than destcopy + # torch.Tensor.copy_(infer_info.cache_manager.key_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],key_states) + # torch.Tensor.copy_(infer_info.cache_manager.value_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],value_states) + + # FIXME might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_info.decode_layer_id == 0: # once per model.forward + infer_info.cache_manager.past_key_values_length += q_len # seq_len + + query_states = query_states.transpose(1, 2) + + if infer_info.is_context_stage: + # first token generation + + # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = 
_flash_attn_forward(query_states, + # key_states, + # value_states, + # 0, + # 1/math.sqrt(self.head_dim), + # causal, + # False) + + attn_output = torch.empty_like(query_states) + + # calcu_shape for context_attention_fwd + calcu_shape1 = (-1, self.num_heads, self.head_dim) + + from inference.lightllm.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd + + context_attention_fwd(query_states.view(calcu_shape1), + key_states.view(calcu_shape1), + value_states.view(calcu_shape1), + attn_output.view(calcu_shape1), + infer_info.start_loc, + infer_info.seq_len, + infer_info.max_len_in_batch) + + else: + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + calcu_shape1 = (-1, self.num_heads, self.head_dim) + att_m_tensor = torch.empty((self.num_heads, infer_info.total_token_num), dtype=query_states.dtype, device="cuda") + + from inference.lightllm.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + from inference.lightllm.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 + from inference.lightllm.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd + + # q*k + token_att_fwd(query_states.view(calcu_shape1), + infer_info.cache_manager.key_buffer[infer_info.decode_layer_id], + att_m_tensor, + infer_info.block_loc, + infer_info.start_loc, + infer_info.seq_len, + infer_info.max_len_in_batch) + + prob = torch.empty_like(att_m_tensor) + token_softmax_fwd(att_m_tensor, infer_info.start_loc, infer_info.seq_len, prob, infer_info.max_len_in_batch) + att_m_tensor = None + + attn_output = torch.empty_like(query_states) + + token_att_fwd2(prob, + infer_info.cache_manager.value_buffer[infer_info.decode_layer_id], + attn_output.view(calcu_shape1), + infer_info.block_loc, + infer_info.start_loc, + infer_info.seq_len, + infer_info.max_len_in_batch) + + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + # return past_key_value as None + return attn_output, None, None + def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 731695c0d8e3..0cc0bd39d05e 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -284,5 +284,13 @@ def module_policy(self): infer_forward = LlamaInferenceForwards.llama_model_forward method_replacement = {'forward': partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) return policy diff --git a/tests/test_infer/llama_infer_eigine.py b/tests/test_infer/llama_infer_eigine.py new file mode 100644 index 000000000000..1f0d9c682031 --- /dev/null +++ b/tests/test_infer/llama_infer_eigine.py @@ -0,0 +1,236 @@ + +import torch.distributed as dist +import torch +import torch.nn as nn +from 
colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.kvcache_manager import MemoryManager +from colossalai.shardformer.policies.llama import LlamaModelInferPolicy +from transformers import LlamaForCausalLM, LlamaTokenizer +import time +from torch.profiler import profile, record_function, ProfilerActivity + +GIGABYTE = 1024 ** 3 +torch.backends.cudnn.enabled = True + +def print_device_memory(): + if torch.cuda.is_available(): + current_device = torch.cuda.current_device() + print(f"Currently using GPU: {current_device}") + + # free memory and the total available memory in bytes + global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info() + memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + memory_reserved = torch.cuda.memory_reserved() + max_memory_reserved = torch.cuda.max_memory_reserved() + + print( + f" free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" + f" total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n" + f" memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n" + f" Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n" + f" memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n" + f" Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" + ) + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print(f"num_layers: {num_layers}, hidden_size: {config.hidden_size}") + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1/avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1/avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000/(avg * 1000)))) + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config,"max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config,"max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + +class TPCacheManagerInferenceEngine: + + def __init__( + self, + input_len: int, + output_len: int, + bs: int, + tp_size: int, + ) -> None: + self.pg_mesh = ProcessGroupMesh(1, 1, tp_size) + self.input_len = input_len + self.output_len = output_len + self.bs = bs + self.tp_size = tp_size + + def 
init_and_insert_cache_manager(self): + + max_total_token_num = self.bs * (self.input_len + self.output_len) + + head_num = self.model.config.num_attention_heads // self.tp_size + + self.cache_manager = MemoryManager( + max_total_token_num, + torch.float16, + head_num, + self.model.config.hidden_size // self.model.config.num_attention_heads, + self.model.config.num_hidden_layers, + device="cuda", + ) + + setattr(self.model.model, 'cache_manager', self.cache_manager) + + block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") + start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + max_total_token_num = self.bs * (self.input_len) + max_len_in_batch = self.input_len + for i in range(self.bs): + block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") + start_loc[i] = i * self.input_len + seq_len[i] = self.input_len + print(block_loc.shape) + print(start_loc.shape) + print(seq_len.shape) + setattr(self.model.model, 'block_loc', block_loc) + setattr(self.model.model, 'start_loc', start_loc) + setattr(self.model.model, 'seq_len', seq_len) + setattr(self.model.model,'total_token_num', max_total_token_num) + setattr(self.model.model,'max_len_in_batch',max_len_in_batch) + + + def prepare_model(self): + llama_model_path = "/data/scratch/llama-7b-hf" + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + self.model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(self.model.model, base=10000) + self.model = self.model.half() + self.model.to(torch.cuda.current_device()) + + + def build_model(self): + # create new model + org_model = self.model + shardconfig = ShardConfig( + tensor_parallel_process_group=self.pg_mesh.get_group_along_axis(2), + enable_tensor_parallelism=True, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(self.model, LlamaModelInferPolicy()) + return org_model.cuda(), shard_model.cuda() + + def generate_data(self): + self.input_tokens={"input_ids":torch.randint(1, 1000, (self.bs, self.input_len))} + for t in self.input_tokens: + if torch.is_tensor(self.input_tokens[t]): + self.input_tokens[t] = self.input_tokens[t].to(torch.cuda.current_device()) + print(f" input_tokens[{t}].shape: {self.input_tokens[t].shape}") + self.input_len = self.input_tokens[t].shape[1] + print(f" input_len: {self.input_len}") + + def run_infer(self): + generate_kwargs = dict(max_new_tokens=self.output_len, do_sample=False) + + iters = 10 + times = [] + warmup = 3 + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = self.model.generate(**self.input_tokens, + **generate_kwargs, early_stopping=False) + torch.cuda.synchronize() + + # manually clear the cache manager bounded to the model + # TODO: add post-process of model generate + # we might want to replace model.generate + # QUESTION: model.generate, or model.model.generate + + end = time.time() + self.model.model.cache_manager.free_all() + num_tokens_generation = outputs.shape[1] - self.input_len + print(f"num_tokens_generation: {num_tokens_generation}") + print(f"generation time is {(end - start) * 1000} ms") + time_spend = (end-start)/num_tokens_generation + times.append(time_spend) + + # print(outputs.shape) + # reset params + block_loc = torch.empty(self.bs, 
self.input_len + self.output_len, dtype=torch.long, device="cuda") + start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + max_len_in_batch = self.input_len + for i in range(self.bs): + block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") + start_loc[i] = i * self.input_len + seq_len[i] = self.input_len + + setattr(self.model.model, 'block_loc', block_loc) + setattr(self.model.model, 'start_loc', start_loc) + setattr(self.model.model, 'seq_len', seq_len) + setattr(self.model.model,'max_len_in_batch',max_len_in_batch) + + print_device_memory() + total_time = (end - start) * 1000 + token_latency = total_time/(self.output_len) + print(outputs.shape) + print('per_token_latency',token_latency,'ms') + print_perf_stats(times, self.model.config, self.bs, warmup=warmup) + + + # reset params for profile + block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") + start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + max_len_in_batch = self.input_len + for i in range(self.bs): + block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") + start_loc[i] = i * self.input_len + seq_len[i] = self.input_len + + setattr(self.model.model, 'block_loc', block_loc) + setattr(self.model.model, 'start_loc', start_loc) + setattr(self.model.model, 'seq_len', seq_len) + setattr(self.model.model,'max_len_in_batch',max_len_in_batch) + + + with profile(activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("model_inference"): + torch.cuda.synchronize() + outputs = self.model.generate(**self.input_tokens, + **generate_kwargs,early_stopping=False) + torch.cuda.synchronize() + + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + + \ No newline at end of file diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 09a81ef7f0a8..54253e99f9e4 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -2,39 +2,31 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.utils import Randomizer -from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_infer._utils import build_model, run_infer +from llama_infer_eigine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' - -def check_infer(model_fn, data_gen_fn, output_transform_fn, test_config): - org_model, sharded_model = build_model(model_fn, **test_config) - - org_output, infer_output = run_infer(org_model, sharded_model, data_gen_fn, output_transform_fn) - - print('original output', org_output[0]) - print('infer output', infer_output[0]) - - @parameterize('test_config', [{ - 'enable_flash_attention': False, + 'tp_size': 2, }]) def run_llama_test(test_config): + input_len = 1024 + output_len = 128 + bs = 8 + engine = TPCacheManagerInferenceEngine(input_len, output_len, bs, 2) + engine.generate_data() + engine.prepare_model() + 
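For reference, the call sequence this test drives, end to end, is summarized below. This is an illustrative sketch only: the sizes mirror the test above, and it assumes the script runs from tests/test_infer so that llama_infer_eigine is importable on a machine with the local llama checkpoint available.

# Illustrative summary of the engine workflow exercised by run_llama_test (not part of the patch).
from llama_infer_eigine import TPCacheManagerInferenceEngine

engine = TPCacheManagerInferenceEngine(input_len=1024, output_len=128, bs=8, tp_size=2)
engine.generate_data()                           # random input_ids of shape (bs, input_len), moved to CUDA
engine.prepare_model()                           # load LlamaForCausalLM in fp16, precompute rotary cos/sin caches
engine.init_and_insert_cache_manager()           # attach MemoryManager and block_loc/start_loc/seq_len to model.model
org_model, sharded_model = engine.build_model()  # shard with ShardFormer + LlamaModelInferPolicy
engine.model = sharded_model
engine.run_infer()                               # timed generate() loop; prints latency, bandwidth and throughput stats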
engine.init_and_insert_cache_manager() + + org_model, sharded_model = engine.build_model() + engine.model = sharded_model + + engine.run_infer() - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != "transformers_llama": - continue - check_infer(model_fn, data_gen_fn, output_transform_fn, test_config) torch.cuda.empty_cache() @@ -48,7 +40,7 @@ def check_llama(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(): - spawn(check_llama, 1) + spawn(check_llama, 2) if __name__ == "__main__": From 7686c07d9ab2afe0cb70e52fc64d4e8ea60c1eb9 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 24 Aug 2023 16:50:35 +0800 Subject: [PATCH 13/28] fix conflict --- colossalai/inference/__init__.py | 4 - colossalai/inference/batch_infer_state.py | 52 ---------- colossalai/inference/kvcache_manager.py | 116 ---------------------- 3 files changed, 172 deletions(-) delete mode 100644 colossalai/inference/__init__.py delete mode 100644 colossalai/inference/batch_infer_state.py delete mode 100644 colossalai/inference/kvcache_manager.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py deleted file mode 100644 index 1bce92653a8e..000000000000 --- a/colossalai/inference/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .batch_infer_state import BatchInferState -from .kvcache_manager import MemoryManager - -__all__ = ['BatchInferState', 'MemoryManager'] diff --git a/colossalai/inference/batch_infer_state.py b/colossalai/inference/batch_infer_state.py deleted file mode 100644 index f06b0aadbf55..000000000000 --- a/colossalai/inference/batch_infer_state.py +++ /dev/null @@ -1,52 +0,0 @@ -# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later -from dataclasses import dataclass -from typing import Any - -import torch - -from colossalai.inference.kvcache_manager import MemoryManager - - -@dataclass -class BatchInferState: - r""" - Information to be passed and used for a batch of inputs during - a single model forward - """ - batch_size: int - max_len_in_batch: int - - cache_manager: MemoryManager = None - - block_loc: torch.Tensor = None - start_loc: torch.Tensor = None - seq_len: torch.Tensor = None - - is_context_stage: bool = False - context_mem_index: torch.Tensor = None - decode_is_contiguous: bool = None - decode_mem_start: int = None - decode_mem_end: int = None - decode_mem_index: torch.Tensor = None - decode_layer_id: int = None - - device: torch.device = torch.device('cuda') - - @property - def total_token_num(self): - return self.batch_size * self.max_len_in_batch - - def set_cache_manager(self, manager: MemoryManager): - self.cache_manager = manager - - @staticmethod - def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, - alloc_mem_index: torch.Tensor): - """ in-place update block loc mapping based on the sequence length of the inputs in current bath""" - start_index = 0 - seq_len_numpy = seq_len.cpu().numpy() - for i, cur_seq_len in enumerate(seq_len_numpy): - b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + - cur_seq_len] - start_index += cur_seq_len - return diff --git a/colossalai/inference/kvcache_manager.py b/colossalai/inference/kvcache_manager.py deleted file mode 100644 index 8f8c40a20890..000000000000 --- 
a/colossalai/inference/kvcache_manager.py +++ /dev/null @@ -1,116 +0,0 @@ -# Adapted from lightllm/common/mem_manager.py -# of the ModelTC/lightllm GitHub repository -# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py -# -# Copyright 2023 ModelTC Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from colossalai.logging import get_dist_logger - - -class MemoryManager: - r""" - Manage token block indexes and allocate physical memory for key and value cache - - Args: - size: maximum token number used as the size of key and value buffer - dtype: data type of cached key and value - head_num: number of heads the memory manager is responsible for - head_dim: embedded size per head - layer_num: the number of layers in the model - device: device used to store the key and value cache - """ - - def __init__(self, - size: int, - dtype: torch.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: torch.device = torch.device('cuda')): - self.logger = get_dist_logger(__name__) - self.available_size = size - self.past_key_values_length = 0 - self._init_mem_states(size, device) - self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) - - def _init_mem_states(self, size, device): - """ Initialize tensors used to manage memory states """ - self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) - self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) - self.indexes = torch.arange(0, size, dtype=torch.long, device=device) - - def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): - """ Initialize key buffer and value buffer on specified device """ - self.key_buffer = [ - torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) - ] - self.value_buffer = [ - torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) - ] - - @torch.no_grad() - def alloc(self, required_size): - """ allocate space of required_size by providing indexes representing available physical spaces """ - if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " - f"left_size {self.available_size}") - return None - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) - select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) - select_index = self.indexes[select_index] - self.mem_state[select_index] = 0 - self.available_size -= len(select_index) - return select_index - - @torch.no_grad() - def alloc_contiguous(self, required_size): - """ allocate contiguous space of required_size """ - if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " - f"left_size {self.available_size}") - return None - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) - sum_size = len(self.mem_cum_sum) - loc_sums = 
self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + - 1] + self.mem_state[0:sum_size - - required_size + 1] - can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] - if can_used_loc.shape[0] == 0: - self.logger.info(f"No enough contiguous cache: required_size {required_size} " - f"left_size {self.available_size}") - return None - start_loc = can_used_loc[0] - select_index = self.indexes[start_loc:start_loc + required_size] - self.mem_state[select_index] = 0 - self.available_size -= len(select_index) - start = start_loc.item() - end = start + required_size - return select_index, start, end - - @torch.no_grad() - def free(self, free_index): - """ free memory by updating memory states based on given indexes """ - self.available_size += free_index.shape[0] - self.mem_state[free_index] = 1 - - @torch.no_grad() - def free_all(self): - """ free all memory by updating memory states """ - self.available_size = len(self.mem_state) - self.mem_state[:] = 1 - self.past_key_values_length = 0 - self.logger.info("freed all space of memory manager") From ba089d7c792a3d4ce94343a19f15f586f07aee33 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 24 Aug 2023 17:15:21 +0800 Subject: [PATCH 14/28] feature: add some new features for llama engine --- colossalai/shardformer/modeling/llama.py | 8 +-- tests/test_infer/llama_infer_eigine.py | 75 +++++++++--------------- tests/test_infer/test_llama_infer.py | 6 +- 3 files changed, 33 insertions(+), 56 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 58c938e64497..eeffcef598ad 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -8,13 +8,9 @@ CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -<<<<<<< HEAD -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaAttention -======= -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm ->>>>>>> 7d7ea2ef41486c3c6f8c3595d482c6e15b403bff +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm, LlamaAttention from transformers.utils import logging -from colossalai.inference.batch_infer_state import BatchInferState +from colossalai.shardformer.inference import BatchInferState from colossalai.pipeline.stage_manager import PipelineStageManager diff --git a/tests/test_infer/llama_infer_eigine.py b/tests/test_infer/llama_infer_eigine.py index 1b70ecbc68fa..640c29b7e3c9 100644 --- a/tests/test_infer/llama_infer_eigine.py +++ b/tests/test_infer/llama_infer_eigine.py @@ -4,7 +4,7 @@ import torch.nn as nn from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.inference.kvcache_manager import MemoryManager +from colossalai.shardformer.inference import MemoryManager from colossalai.shardformer.policies.llama import LlamaModelInferPolicy, LlamaPolicy from transformers import LlamaForCausalLM, LlamaTokenizer import time @@ -138,7 +138,7 @@ def prepare_model(self): def build_model(self): # create new model - org_model = self.model + self.orgin_model = self.model shardconfig = ShardConfig( tensor_parallel_process_group=self.pg_mesh.get_group_along_axis(2), enable_tensor_parallelism=True, @@ -151,8 +151,7 @@ def build_model(self): else: policy = 
LlamaPolicy() - shard_model, _ = shardformer.optimize(self.model, policy) - return org_model.cuda(), shard_model.cuda() + self.model, _ = shardformer.optimize(self.model, policy) def generate_data(self): self.input_tokens={"input_ids":torch.randint(1, 1000, (self.bs, self.input_len))} @@ -163,40 +162,49 @@ def generate_data(self): self.input_len = self.input_tokens[t].shape[1] print(f" input_len: {self.input_len}") - def run_infer(self): + def run_infer(self, test_origin=True): + + if test_origin: + model = self.orgin_model + else: + model = self.model + generate_kwargs = dict(max_new_tokens=self.output_len, do_sample=False) iters = 10 times = [] + outputs_list = [] warmup = 3 for i in range(iters): torch.cuda.synchronize() start = time.time() - outputs = self.model.generate(**self.input_tokens, + outputs = model.generate(**self.input_tokens, **generate_kwargs, early_stopping=False) + outputs_list.append(outputs) torch.cuda.synchronize() end = time.time() - self.model.model.cache_manager.free_all() num_tokens_generation = outputs.shape[1] - self.input_len print(f"num_tokens_generation: {num_tokens_generation}") print(f"generation time is {(end - start) * 1000} ms") time_spend = (end-start)/num_tokens_generation times.append(time_spend) - block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") - start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - max_len_in_batch = self.input_len - for i in range(self.bs): - block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") - start_loc[i] = i * self.input_len - seq_len[i] = self.input_len - - setattr(self.model.model, 'block_loc', block_loc) - setattr(self.model.model, 'start_loc', start_loc) - setattr(self.model.model, 'seq_len', seq_len) - setattr(self.model.model,'max_len_in_batch',max_len_in_batch) + if test_origin: + model.model.cache_manager.free_all() + block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") + start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") + max_len_in_batch = self.input_len + for i in range(self.bs): + block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") + start_loc[i] = i * self.input_len + seq_len[i] = self.input_len + + setattr(model.model, 'block_loc', block_loc) + setattr(model.model, 'start_loc', start_loc) + setattr(model.model, 'seq_len', seq_len) + setattr(model.model,'max_len_in_batch',max_len_in_batch) print_device_memory() total_time = (end - start) * 1000 @@ -204,33 +212,8 @@ def run_infer(self): print(outputs.shape) print('per_token_latency',token_latency,'ms') print_perf_stats(times, self.model.config, self.bs, warmup=warmup) - - - # reset params for profile - block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") - start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - max_len_in_batch = self.input_len - for i in range(self.bs): - block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") - start_loc[i] = i * self.input_len - seq_len[i] = self.input_len - setattr(self.model.model, 'block_loc', block_loc) - setattr(self.model.model, 'start_loc', 
start_loc) - setattr(self.model.model, 'seq_len', seq_len) - setattr(self.model.model,'max_len_in_batch',max_len_in_batch) - - - with profile(activities=[ - ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("model_inference"): - torch.cuda.synchronize() - outputs = self.model.generate(**self.input_tokens, - **generate_kwargs,early_stopping=False) - torch.cuda.synchronize() - - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + return outputs_list \ No newline at end of file diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 54253e99f9e4..0103b17c0b71 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -20,12 +20,10 @@ def run_llama_test(test_config): engine = TPCacheManagerInferenceEngine(input_len, output_len, bs, 2) engine.generate_data() engine.prepare_model() - engine.init_and_insert_cache_manager() - org_model, sharded_model = engine.build_model() - engine.model = sharded_model + engine.build_model() - engine.run_infer() + engine.run_infer(test_origin=True) torch.cuda.empty_cache() From 68b5fe83a06401d65c0319226d6ce5b0cb6b90f5 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 24 Aug 2023 18:26:21 +0800 Subject: [PATCH 15/28] adapt colossalai triton interface --- colossalai/shardformer/modeling/llama.py | 51 ++++++------------- ..._infer_eigine.py => llama_infer_engine.py} | 7 ++- tests/test_infer/test_llama_infer.py | 4 +- 3 files changed, 20 insertions(+), 42 deletions(-) rename tests/test_infer/{llama_infer_eigine.py => llama_infer_engine.py} (98%) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index eeffcef598ad..75f3d79555dd 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -11,6 +11,9 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm, LlamaAttention from transformers.utils import logging from colossalai.shardformer.inference import BatchInferState +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +from colossalai.kernel.triton.context_attention import llama_context_attn_fwd +from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager @@ -625,15 +628,13 @@ def llama_flash_attn_kvcache_forward( rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) - from inference.ops.triton.k_copy_kv import destindex_copy_kv - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): num_heads = key_buffer.shape[2] head_dim = key_buffer.shape[3] key_buffer = key_buffer.view(-1, num_heads, head_dim) value_buffer = value_buffer.view(-1, num_heads, head_dim) - destindex_copy_kv(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - destindex_copy_kv(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) return # copy key and value calculated in current step to memory manager @@ -669,10 +670,8 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # calcu_shape for context_attention_fwd calcu_shape1 = 
(-1, self.num_heads, self.head_dim) - - from inference.lightllm.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd - context_attention_fwd(query_states.view(calcu_shape1), + llama_context_attn_fwd(query_states.view(calcu_shape1), key_states.view(calcu_shape1), value_states.view(calcu_shape1), attn_output.view(calcu_shape1), @@ -684,36 +683,16 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # second token and follows # kv = torch.stack((key_states, value_states), dim=2) # (batch_size, seqlen, nheads, headdim) - calcu_shape1 = (-1, self.num_heads, self.head_dim) - att_m_tensor = torch.empty((self.num_heads, infer_info.total_token_num), dtype=query_states.dtype, device="cuda") - - from inference.lightllm.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - from inference.lightllm.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 - from inference.lightllm.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd - - # q*k - token_att_fwd(query_states.view(calcu_shape1), - infer_info.cache_manager.key_buffer[infer_info.decode_layer_id], - att_m_tensor, - infer_info.block_loc, - infer_info.start_loc, - infer_info.seq_len, - infer_info.max_len_in_batch) - - prob = torch.empty_like(att_m_tensor) - token_softmax_fwd(att_m_tensor, infer_info.start_loc, infer_info.seq_len, prob, infer_info.max_len_in_batch) - att_m_tensor = None - attn_output = torch.empty_like(query_states) - - token_att_fwd2(prob, - infer_info.cache_manager.value_buffer[infer_info.decode_layer_id], - attn_output.view(calcu_shape1), - infer_info.block_loc, - infer_info.start_loc, - infer_info.seq_len, - infer_info.max_len_in_batch) - + + token_attention_fwd(query_states, + infer_info.cache_manager.key_buffer[infer_info.decode_layer_id], + infer_info.cache_manager.value_buffer[infer_info.decode_layer_id], + attn_output, + infer_info.block_loc, + infer_info.start_loc, + infer_info.seq_len, + infer_info.max_len_in_batch) attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/tests/test_infer/llama_infer_eigine.py b/tests/test_infer/llama_infer_engine.py similarity index 98% rename from tests/test_infer/llama_infer_eigine.py rename to tests/test_infer/llama_infer_engine.py index 640c29b7e3c9..aa1574c21411 100644 --- a/tests/test_infer/llama_infer_eigine.py +++ b/tests/test_infer/llama_infer_engine.py @@ -12,6 +12,7 @@ GIGABYTE = 1024 ** 3 torch.backends.cudnn.enabled = True +DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 def print_device_memory(): if torch.cuda.is_available(): @@ -116,9 +117,6 @@ def init_and_insert_cache_manager(self): block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") start_loc[i] = i * self.input_len seq_len[i] = self.input_len - print(block_loc.shape) - print(start_loc.shape) - print(seq_len.shape) setattr(self.model.model, 'block_loc', block_loc) setattr(self.model.model, 'start_loc', start_loc) setattr(self.model.model, 'seq_len', seq_len) @@ -140,8 +138,9 @@ def build_model(self): # create new model self.orgin_model = self.model shardconfig = ShardConfig( - tensor_parallel_process_group=self.pg_mesh.get_group_along_axis(2), + tensor_parallel_process_group=self.pg_mesh.get_group_along_axis(TP_DIM), enable_tensor_parallelism=True, + inference_only=True, ) shardformer = ShardFormer(shard_config=shardconfig) diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 0103b17c0b71..fdc27aca8306 
100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -6,7 +6,7 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from llama_infer_eigine import TPCacheManagerInferenceEngine +from llama_infer_engine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -17,7 +17,7 @@ def run_llama_test(test_config): input_len = 1024 output_len = 128 bs = 8 - engine = TPCacheManagerInferenceEngine(input_len, output_len, bs, 2) + engine = TPCacheManagerInferenceEngine(input_len, output_len, bs, test_config["tp_size"]) engine.generate_data() engine.prepare_model() From 6021b1311934e13e5cddb5e14267b5de9b6fef11 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 24 Aug 2023 20:27:34 +0800 Subject: [PATCH 16/28] Change the parent class of llama policy --- colossalai/shardformer/policies/llama.py | 4 ++-- tests/test_infer/llama_infer_engine.py | 9 +++++---- tests/test_infer/test_llama_infer.py | 7 ++++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 0cc0bd39d05e..01eef5e10d49 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -10,7 +10,7 @@ from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] +__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy', "LlamaForCausalLMPolicy"] class LlamaPolicy(Policy): @@ -265,7 +265,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] -class LlamaModelInferPolicy(LlamaPolicy): +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): def __init__(self) -> None: super().__init__() diff --git a/tests/test_infer/llama_infer_engine.py b/tests/test_infer/llama_infer_engine.py index aa1574c21411..a123443522c9 100644 --- a/tests/test_infer/llama_infer_engine.py +++ b/tests/test_infer/llama_infer_engine.py @@ -1,11 +1,12 @@ +import copy import torch.distributed as dist import torch import torch.nn as nn from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.inference import MemoryManager -from colossalai.shardformer.policies.llama import LlamaModelInferPolicy, LlamaPolicy +from colossalai.shardformer.policies.llama import LlamaModelInferPolicy, LlamaForCausalLMPolicy from transformers import LlamaForCausalLM, LlamaTokenizer import time from torch.profiler import profile, record_function, ProfilerActivity @@ -136,7 +137,7 @@ def prepare_model(self): def build_model(self): # create new model - self.orgin_model = self.model + self.orgin_model = copy.deepcopy(self.model) shardconfig = ShardConfig( tensor_parallel_process_group=self.pg_mesh.get_group_along_axis(TP_DIM), enable_tensor_parallelism=True, @@ -148,7 +149,7 @@ def build_model(self): self.init_and_insert_cache_manager() policy = LlamaModelInferPolicy() else: - policy = LlamaPolicy() + policy = LlamaForCausalLMPolicy() self.model, _ = shardformer.optimize(self.model, policy) @@ -189,7 +190,7 @@ def run_infer(self, test_origin=True): time_spend = 
(end-start)/num_tokens_generation times.append(time_spend) - if test_origin: + if not test_origin: model.model.cache_manager.free_all() block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index fdc27aca8306..805b3ef401aa 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -9,9 +9,10 @@ from llama_infer_engine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 2 @parameterize('test_config', [{ - 'tp_size': 2, + 'tp_size': TPSIZE, }]) def run_llama_test(test_config): input_len = 1024 @@ -23,7 +24,7 @@ def run_llama_test(test_config): engine.build_model() - engine.run_infer(test_origin=True) + engine.run_infer(test_origin=False) torch.cuda.empty_cache() @@ -38,7 +39,7 @@ def check_llama(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(): - spawn(check_llama, 2) + spawn(check_llama, TPSIZE) if __name__ == "__main__": From 6a1bafaf3d2ea77281800e66ec902b834448770c Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Fri, 25 Aug 2023 14:26:37 +0800 Subject: [PATCH 17/28] add nvtx --- colossalai/shardformer/modeling/llama.py | 4 ++++ tests/test_infer/llama_infer_engine.py | 7 ++++++- tests/test_infer/test_llama_infer.py | 7 ++++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 75f3d79555dd..c7eaf03040af 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -16,6 +16,7 @@ from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager +import PyNVTX as nvtx class LlamaPipelineForwards: @@ -402,6 +403,7 @@ class LlamaInferenceForwards: """ @staticmethod + @nvtx.annotate("llama_model_forward") def llama_model_forward( self: LlamaModel, input_ids: torch.LongTensor = None, @@ -549,6 +551,7 @@ def llama_model_forward( ) @staticmethod + @nvtx.annotate("llama_decoder_layer_forward") def llama_decoder_layer_forward( self, hidden_states: torch.Tensor, @@ -595,6 +598,7 @@ def llama_decoder_layer_forward( @staticmethod + @nvtx.annotate("llama_flash_attn_kvcache_forward") def llama_flash_attn_kvcache_forward( self: LlamaAttention, hidden_states: torch.Tensor, diff --git a/tests/test_infer/llama_infer_engine.py b/tests/test_infer/llama_infer_engine.py index a123443522c9..18188ef3c3c4 100644 --- a/tests/test_infer/llama_infer_engine.py +++ b/tests/test_infer/llama_infer_engine.py @@ -10,6 +10,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer import time from torch.profiler import profile, record_function, ProfilerActivity +import PyNVTX as nvtx GIGABYTE = 1024 ** 3 torch.backends.cudnn.enabled = True @@ -146,9 +147,11 @@ def build_model(self): shardformer = ShardFormer(shard_config=shardconfig) if self.bs >= 4: + self.use_cache_manager = True self.init_and_insert_cache_manager() policy = LlamaModelInferPolicy() else: + self.use_cache_manager = False policy = LlamaForCausalLMPolicy() self.model, _ = shardformer.optimize(self.model, policy) @@ -178,8 +181,10 @@ def run_infer(self, test_origin=True): for i in range(iters): torch.cuda.synchronize() start = time.time() + nvtx.RangePushA("generate") outputs = 
model.generate(**self.input_tokens, **generate_kwargs, early_stopping=False) + nvtx.RangePop() outputs_list.append(outputs) torch.cuda.synchronize() @@ -190,7 +195,7 @@ def run_infer(self, test_origin=True): time_spend = (end-start)/num_tokens_generation times.append(time_spend) - if not test_origin: + if not test_origin and self.use_cache_manager: model.model.cache_manager.free_all() block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 805b3ef401aa..dfab2de1f2dc 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -2,6 +2,7 @@ import pytest import torch +import numpy as np import colossalai from colossalai.logging import disable_existing_loggers @@ -9,7 +10,7 @@ from llama_infer_engine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -TPSIZE = 2 +TPSIZE = 1 @parameterize('test_config', [{ 'tp_size': TPSIZE, @@ -24,8 +25,8 @@ def run_llama_test(test_config): engine.build_model() - engine.run_infer(test_origin=False) - + outputs_list = engine.run_infer(test_origin=False) + torch.cuda.empty_cache() From f79308ec19954b67f090394ce85e51fa4346c0b7 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Sun, 27 Aug 2023 22:48:28 +0800 Subject: [PATCH 18/28] move llama inference code to tensor_parallel --- colossalai/shardformer/modeling/llama.py | 316 +------------ colossalai/shardformer/policies/llama.py | 33 +- colossalai/tensor_parallel/__init__.py | 0 .../tensor_parallel/tpinference/__init__.py | 34 ++ .../tpinference}/llama_infer_engine.py | 9 +- .../tpinference/modeling/__init__.py | 0 .../tpinference/modeling/llama.py | 424 ++++++++++++++++++ .../tpinference/pollcies/__init__.py | 0 .../tpinference/pollcies/llama.py | 35 ++ tests/test_infer/test_llama_infer.py | 2 +- 10 files changed, 498 insertions(+), 355 deletions(-) create mode 100644 colossalai/tensor_parallel/__init__.py create mode 100644 colossalai/tensor_parallel/tpinference/__init__.py rename {tests/test_infer => colossalai/tensor_parallel/tpinference}/llama_infer_engine.py (96%) create mode 100644 colossalai/tensor_parallel/tpinference/modeling/__init__.py create mode 100644 colossalai/tensor_parallel/tpinference/modeling/llama.py create mode 100644 colossalai/tensor_parallel/tpinference/pollcies/__init__.py create mode 100644 colossalai/tensor_parallel/tpinference/pollcies/llama.py diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index c7eaf03040af..294ab87709c6 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,22 +1,16 @@ from typing import Callable, List, Optional, Tuple import torch -import numpy as np from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm, LlamaAttention +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm from transformers.utils import logging -from colossalai.shardformer.inference import BatchInferState -from colossalai.kernel.triton.copy_kv_cache_dest 
import copy_kv_cache_to_dest -from colossalai.kernel.triton.context_attention import llama_context_attn_fwd -from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager -import PyNVTX as nvtx class LlamaPipelineForwards: @@ -397,314 +391,6 @@ def llama_for_sequence_classification_forward( return {'hidden_states': hidden_states} -class LlamaInferenceForwards: - """ - This class holds forwards for llama inference. - """ - - @staticmethod - @nvtx.annotate("llama_model_forward") - def llama_model_forward( - self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - batch_size = input_ids.shape[0] # input_ids.shape[0] - - infer_info = BatchInferState(batch_size, input_ids.shape[1]) - infer_info.batch_size = batch_size - # NOTE: dummy implementation here for testing, just assume all inputs same length - infer_info.block_loc = self.block_loc - infer_info.start_loc = self.start_loc - infer_info.seq_len = self.seq_len - infer_info.max_len_in_batch = self.max_len_in_batch - - b_seq_len_numpy = infer_info.seq_len.cpu().numpy() - position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() - - # this equals - infer_info.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) - infer_info.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - # TODO dummy but work, revise it - past_key_values_length = self.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - infer_info.set_cache_manager(self.cache_manager) - - # FIXME: differentiate with prefill stage - # block_loc require different value-assigning method for two different stage - if use_cache and seq_length != 1: - # NOTE assuem prefill stage - # allocate memory block - infer_info.is_context_stage = True # set prefill stage, notify attention layer - infer_info.context_mem_index = infer_info.cache_manager.alloc(infer_info.total_token_num) - infer_info.init_block_loc(infer_info.block_loc, infer_info.seq_len, seq_length, infer_info.context_mem_index) - else: - # TODO handle the condition that no contiguous memory presents - alloc_mem = self.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_info.decode_mem_index = alloc_mem[0] - 
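At each decode step the code above reserves one cache slot per sequence, preferring a contiguous run and falling back to scattered slots. The snippet below exercises that allocation path in isolation; the sizes are made up for illustration, and the import path is the one the engine uses at this point in the series.

# Illustration only: per-step KV-slot allocation with toy sizes (requires a CUDA device).
import torch
from colossalai.shardformer.inference import MemoryManager

mgr = MemoryManager(16, torch.float16, 2, 4, 1, device="cuda")   # size, dtype, head_num, head_dim, layer_num
batch_size = 4
alloc_mem = mgr.alloc_contiguous(batch_size)   # preferred: (indexes, start, end) for one contiguous run
if alloc_mem is not None:
    decode_mem_index, start, end = alloc_mem
else:
    decode_mem_index = mgr.alloc(batch_size)   # fallback: any free slots, not necessarily contiguous
# each sequence then records its slot for the current position, e.g.
# block_loc[:, cur_pos] = decode_mem_index
mgr.free(decode_mem_index)                     # slots return to the pool once the step/run is done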
infer_info.decode_mem_start = alloc_mem[1] - infer_info.decode_mem_end = alloc_mem[2] - infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print(f" infer_info.cache_manager.past_key_values_length: {infer_info.cache_manager.past_key_values_length}") - infer_info.decode_is_contiguous = False - alloc_mem = self.cache_manager.alloc(batch_size) - infer_info.decode_mem_index = alloc_mem - # infer_info.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_info.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - infer_info.decode_layer_id = 0 - - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] if past_key_values is not None else None - # NOTE: modify here for passing args to decoder layer - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - infer_info=infer_info, - ) - infer_info.decode_layer_id += 1 - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - hidden_states = self.norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None - - # update indices - self.max_len_in_batch += 1 - self.block_loc[:, self.max_len_in_batch-1] = self.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - self.start_loc = self.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - self.total_token_num += batch_size - self.seq_len += 1 - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - @staticmethod - @nvtx.annotate("llama_decoder_layer_forward") - def llama_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - infer_info: 
Optional[BatchInferState] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - infer_info=infer_info, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - - @staticmethod - @nvtx.annotate("llama_flash_attn_kvcache_forward") - def llama_flash_attn_kvcache_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - infer_info: Optional[BatchInferState] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - assert use_cache is True, "use_cache should be set to True using this llama attention" - - bsz, q_len, _ = hidden_states.size() - - # TODO might think about better way to handle transposed k and v - # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] - # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states_transposed = key_states.transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - - # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len) - cos ,sin = infer_info.position_cos, infer_info.position_sin - - cos_sin_cache = torch.cat((cos, sin), dim=-1) - - from col_pos_encoding_ops import rotary_embedding_neox - - rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) - - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - num_heads = key_buffer.shape[2] - head_dim = key_buffer.shape[3] - key_buffer = key_buffer.view(-1, num_heads, head_dim) - value_buffer = value_buffer.view(-1, num_heads, head_dim) - copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - return - - # copy key and value calculated in current step to memory manager - if infer_info.is_context_stage: - _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.context_mem_index, infer_info.cache_manager) - else: - _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.decode_mem_index, infer_info.cache_manager) - - # this is worse than destcopy - # torch.Tensor.copy_(infer_info.cache_manager.key_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],key_states) - # 
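Conceptually, the decode branch that follows computes ordinary attention for one new token per sequence against the cached keys and values gathered through block_loc; this is what the fused token_attention_fwd (previously the token_att_fwd / token_softmax_fwd / token_att_fwd2 trio) implements in Triton. A plain-PyTorch reference for a single sequence, with illustrative names and shapes and none of the batching details, is:

# Reference-only equivalent of the single-token attention step for one sequence (not the Triton kernels).
import math
import torch

def naive_decode_attention(q, key_cache, value_cache, loc):
    # q:                     (num_heads, head_dim) projection of the newly generated token
    # key_cache/value_cache: (size, num_heads, head_dim) per-layer buffers of the memory manager
    # loc:                   (cur_seq_len,) long tensor of cache slots, i.e. one row of block_loc
    k = key_cache[loc]                                                    # gather this sequence's cached keys
    v = value_cache[loc]
    scores = torch.einsum("hd,shd->hs", q, k) / math.sqrt(q.shape[-1])    # q·k over cached positions
    prob = torch.softmax(scores, dim=-1)                                  # softmax over the sequence so far
    return torch.einsum("hs,shd->hd", prob, v)                            # weighted sum of cached values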
torch.Tensor.copy_(infer_info.cache_manager.value_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],value_states) - - # FIXME might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - if infer_info.decode_layer_id == 0: # once per model.forward - infer_info.cache_manager.past_key_values_length += q_len # seq_len - - query_states = query_states.transpose(1, 2) - - if infer_info.is_context_stage: - # first token generation - - # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states, - # key_states, - # value_states, - # 0, - # 1/math.sqrt(self.head_dim), - # causal, - # False) - - attn_output = torch.empty_like(query_states) - - # calcu_shape for context_attention_fwd - calcu_shape1 = (-1, self.num_heads, self.head_dim) - - llama_context_attn_fwd(query_states.view(calcu_shape1), - key_states.view(calcu_shape1), - value_states.view(calcu_shape1), - attn_output.view(calcu_shape1), - infer_info.start_loc, - infer_info.seq_len, - infer_info.max_len_in_batch) - - else: - # second token and follows - # kv = torch.stack((key_states, value_states), dim=2) - # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_states) - - token_attention_fwd(query_states, - infer_info.cache_manager.key_buffer[infer_info.decode_layer_id], - infer_info.cache_manager.value_buffer[infer_info.decode_layer_id], - attn_output, - infer_info.block_loc, - infer_info.start_loc, - infer_info.seq_len, - infer_info.max_len_in_batch) - - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - # return past_key_value as None - return attn_output, None, None - def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 01eef5e10d49..9f9f142cae97 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -7,7 +7,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from ..modeling.llama import LlamaInferenceForwards, LlamaPipelineForwards, get_llama_flash_attention_forward +from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy', "LlamaForCausalLMPolicy"] @@ -263,34 +263,3 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: """No shared params in llama for sequence classification model""" return [] - - -class LlamaModelInferPolicy(LlamaForCausalLMPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel - policy = super().module_policy() - self.shard_config._infer() - - # example for replace layer or decoder - # if self.shard_config.enable_flash_attention: - # policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ - # 'forward': get_llama_flash_attention_forward(), - # }) - - infer_forward = LlamaInferenceForwards.llama_model_forward - method_replacement = {'forward': partial(infer_forward)} - 
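The policy being relocated here is what ShardFormer consumes when preparing the model for tensor-parallel inference: it swaps the forwards of LlamaModel, LlamaDecoderLayer and LlamaAttention for the KV-cache-aware versions. A minimal sketch of applying it, mirroring the engine's build_model (process-group setup elided; model, tp_group and policy are supplied by the caller):

# Sketch of applying the inference policy via ShardFormer (illustrative only; setup elided).
from colossalai.shardformer import ShardConfig, ShardFormer

def build_infer_model(model, tp_group, policy):
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,   # e.g. pg_mesh.get_group_along_axis(TP_DIM)
        enable_tensor_parallelism=True,
        inference_only=True,                      # inference-only flag passed by the engine's build_model
    )
    shardformer = ShardFormer(shard_config=shard_config)
    sharded_model, _ = shardformer.optimize(model, policy)   # e.g. policy = LlamaModelInferPolicy()
    return sharded_model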
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) - - infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer) - - infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) - - return policy diff --git a/colossalai/tensor_parallel/__init__.py b/colossalai/tensor_parallel/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/tensor_parallel/tpinference/__init__.py b/colossalai/tensor_parallel/tpinference/__init__.py new file mode 100644 index 000000000000..64533cdde731 --- /dev/null +++ b/colossalai/tensor_parallel/tpinference/__init__.py @@ -0,0 +1,34 @@ +from functools import partial + +from .modeling.llama import LlamaInferenceForwards +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + policy = super().module_policy() + self.shard_config._infer() + + # example for replace layer or decoder + # if self.shard_config.enable_flash_attention: + # policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + # 'forward': get_llama_flash_attention_forward(), + # }) + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) + + return policy \ No newline at end of file diff --git a/tests/test_infer/llama_infer_engine.py b/colossalai/tensor_parallel/tpinference/llama_infer_engine.py similarity index 96% rename from tests/test_infer/llama_infer_engine.py rename to colossalai/tensor_parallel/tpinference/llama_infer_engine.py index 18188ef3c3c4..e2cc61de8e2e 100644 --- a/tests/test_infer/llama_infer_engine.py +++ b/colossalai/tensor_parallel/tpinference/llama_infer_engine.py @@ -1,16 +1,13 @@ import copy -import torch.distributed as dist import torch -import torch.nn as nn from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.inference import MemoryManager -from colossalai.shardformer.policies.llama import LlamaModelInferPolicy, LlamaForCausalLMPolicy +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy +from colossalai.tensor_parallel.tpinference.pollcies.llama import LlamaModelInferPolicy from transformers import LlamaForCausalLM, LlamaTokenizer import time -from torch.profiler import profile, 
record_function, ProfilerActivity -import PyNVTX as nvtx GIGABYTE = 1024 ** 3 torch.backends.cudnn.enabled = True @@ -181,10 +178,8 @@ def run_infer(self, test_origin=True): for i in range(iters): torch.cuda.synchronize() start = time.time() - nvtx.RangePushA("generate") outputs = model.generate(**self.input_tokens, **generate_kwargs, early_stopping=False) - nvtx.RangePop() outputs_list.append(outputs) torch.cuda.synchronize() diff --git a/colossalai/tensor_parallel/tpinference/modeling/__init__.py b/colossalai/tensor_parallel/tpinference/modeling/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/tensor_parallel/tpinference/modeling/llama.py b/colossalai/tensor_parallel/tpinference/modeling/llama.py new file mode 100644 index 000000000000..f91e9dc4b770 --- /dev/null +++ b/colossalai/tensor_parallel/tpinference/modeling/llama.py @@ -0,0 +1,424 @@ +from typing import List, Optional, Tuple + +import torch +import numpy as np +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.models.llama.modeling_llama import LlamaModel, LlamaRMSNorm, LlamaAttention +from colossalai.shardformer.inference import BatchInferState +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +from colossalai.kernel.triton.context_attention import llama_context_attn_fwd +from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + +class LlamaInferenceForwards: + """ + This class holds forwards for llama inference. + """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + batch_size = input_ids.shape[0] # input_ids.shape[0] + + infer_info = BatchInferState(batch_size, input_ids.shape[1]) + infer_info.batch_size = batch_size + # NOTE: dummy implementation here for testing, just assume all inputs same length + infer_info.block_loc = self.block_loc + infer_info.start_loc = self.start_loc + infer_info.seq_len = self.seq_len + infer_info.max_len_in_batch = self.max_len_in_batch + + b_seq_len_numpy = infer_info.seq_len.cpu().numpy() + position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) + for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + + # this equals + infer_info.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) + infer_info.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + # TODO dummy 
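To make the position handling above concrete: the per-sequence lengths tracked in seq_len are flattened into one position id per token, and those ids gather rows of the cos/sin tables precomputed by init_to_get_rotary. A tiny CPU-only illustration with made-up sizes:

# Toy illustration of the position-id flattening and cos/sin gather (sizes are made up).
import torch

seq_len = torch.tensor([3, 2])                        # two sequences currently holding 3 and 2 tokens
position_ids = torch.cat([torch.arange(0, int(n)) for n in seq_len])
# position_ids is tensor([0, 1, 2, 0, 1]), one entry per token in the flattened batch

head_dim = 8
cos_cached = torch.randn(4096, head_dim // 2)         # stand-in for the precomputed _cos_cached table
position_cos = torch.index_select(cos_cached, 0, position_ids)   # shape (5, head_dim // 2)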
but work, revise it + past_key_values_length = self.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + infer_info.set_cache_manager(self.cache_manager) + + # FIXME: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assuem prefill stage + # allocate memory block + infer_info.is_context_stage = True # set prefill stage, notify attention layer + infer_info.context_mem_index = infer_info.cache_manager.alloc(infer_info.total_token_num) + infer_info.init_block_loc(infer_info.block_loc, infer_info.seq_len, seq_length, infer_info.context_mem_index) + else: + # TODO handle the condition that no contiguous memory presents + alloc_mem = self.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_info.decode_mem_index = alloc_mem[0] + infer_info.decode_mem_start = alloc_mem[1] + infer_info.decode_mem_end = alloc_mem[2] + infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print(f" infer_info.cache_manager.past_key_values_length: {infer_info.cache_manager.past_key_values_length}") + infer_info.decode_is_contiguous = False + alloc_mem = self.cache_manager.alloc(batch_size) + infer_info.decode_mem_index = alloc_mem + # infer_info.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_info.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + infer_info.decode_layer_id = 0 + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] if past_key_values is not None else None + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_info=infer_info, + ) + infer_info.decode_layer_id += 1 + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + + # update indices + self.max_len_in_batch 
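The prefill/decode branch above asks the cache manager for one KV slot per prompt token during the context (prefill) pass, and for a single new slot per sequence during decode, preferring a contiguous run so the copy stays cheap. A self-contained toy allocator illustrating the two paths, assuming a boolean free-slot mask similar in spirit to the memory manager's mem_state (a sketch, not the actual MemoryManager):

import torch


class ToyKVAllocator:
    # Toy stand-in for the KV-cache memory manager: True in mem_state means the slot is free.
    def __init__(self, size: int):
        self.mem_state = torch.ones(size, dtype=torch.bool)

    def alloc(self, required_size: int):
        # Prefill path: grab any `required_size` free slots (not necessarily contiguous).
        free = torch.nonzero(self.mem_state).flatten()
        if free.numel() < required_size:
            return None
        index = free[:required_size]
        self.mem_state[index] = False
        return index

    def alloc_contiguous(self, required_size: int):
        # Decode path: prefer a contiguous run so the new tokens sit next to each other.
        for start in range(self.mem_state.numel() - required_size + 1):
            if self.mem_state[start:start + required_size].all():
                self.mem_state[start:start + required_size] = False
                return torch.arange(start, start + required_size), start, start + required_size
        return None


manager = ToyKVAllocator(16)
prompt_slots = manager.alloc(3 + 5)         # prefill: one slot per prompt token
decode_slots = manager.alloc_contiguous(2)  # decode: one fresh slot per sequence
print(prompt_slots.tolist(), decode_slots[1], decode_slots[2])  # [0..7] 8 10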
+= 1 + self.block_loc[:, self.max_len_in_batch-1] = self.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + self.start_loc = self.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + self.total_token_num += batch_size + self.seq_len += 1 + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + @staticmethod + def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_info: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_info=infer_info, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_info: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # TODO might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states_transposed = key_states.transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len) + cos ,sin = infer_info.position_cos, infer_info.position_sin + + cos_sin_cache = torch.cat((cos, sin), dim=-1) + + from col_pos_encoding_ops import rotary_embedding_neox + + rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + num_heads = key_buffer.shape[2] + head_dim = key_buffer.shape[3] + key_buffer = 
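llama_decoder_layer_forward above keeps the usual pre-norm residual wiring of a Llama block and only adds the extra infer_info argument that is threaded through to attention. For orientation, a bare-bones pre-norm block with stand-in submodules (toy dimensions, no KV cache or rotary logic):

import torch
import torch.nn as nn


class ToyPreNormBlock(nn.Module):
    # Pre-norm residual block: norm -> mixer -> add, twice (attention, then MLP).
    def __init__(self, dim: int = 8):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(dim)
        self.self_attn = nn.Linear(dim, dim)  # stand-in for the attention module
        self.post_attention_layernorm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + hidden_states


print(ToyPreNormBlock()(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])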
key_buffer.view(-1, num_heads, head_dim) + value_buffer = value_buffer.view(-1, num_heads, head_dim) + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + # copy key and value calculated in current step to memory manager + if infer_info.is_context_stage: + _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.context_mem_index, infer_info.cache_manager) + else: + _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.decode_mem_index, infer_info.cache_manager) + + # this is worse than destcopy + # torch.Tensor.copy_(infer_info.cache_manager.key_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],key_states) + # torch.Tensor.copy_(infer_info.cache_manager.value_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],value_states) + + # FIXME might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_info.decode_layer_id == 0: # once per model.forward + infer_info.cache_manager.past_key_values_length += q_len # seq_len + + query_states = query_states.transpose(1, 2) + + if infer_info.is_context_stage: + # first token generation + + # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states, + # key_states, + # value_states, + # 0, + # 1/math.sqrt(self.head_dim), + # causal, + # False) + + attn_output = torch.empty_like(query_states) + + # calcu_shape for context_attention_fwd + calcu_shape1 = (-1, self.num_heads, self.head_dim) + + llama_context_attn_fwd(query_states.view(calcu_shape1), + key_states.view(calcu_shape1), + value_states.view(calcu_shape1), + attn_output.view(calcu_shape1), + infer_info.start_loc, + infer_info.seq_len, + infer_info.max_len_in_batch) + + else: + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + token_attention_fwd(query_states, + infer_info.cache_manager.key_buffer[infer_info.decode_layer_id], + infer_info.cache_manager.value_buffer[infer_info.decode_layer_id], + attn_output, + infer_info.block_loc, + infer_info.start_loc, + infer_info.seq_len, + infer_info.max_len_in_batch) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None + +def get_llama_flash_attention_forward(): + + from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + try: + from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True + except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") + HAS_VLLM_KERNERL = False + + + def forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: 
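_copy_kv_to_mem_cache above flattens the step's key/value tensors to (tokens, heads, head_dim) and scatters them into the pooled per-layer buffers at the slot indices handed out by the allocator. A minimal sketch of that destination-indexed copy with made-up shapes, using plain index_copy_ in place of the Triton copy_kv_cache_to_dest kernel:

import torch

num_slots, num_heads, head_dim = 32, 4, 8
key_buffer_pool = torch.zeros(num_slots, num_heads, head_dim)  # per-layer pooled key cache

# Keys produced this step for 6 tokens, flattened to (tokens, heads, head_dim).
new_keys = torch.randn(6, num_heads, head_dim)
dest_index = torch.tensor([10, 11, 12, 20, 21, 22])            # slots handed out by the allocator

# Scatter each row into its destination slot, like copy_kv_cache_to_dest does on GPU.
key_buffer_pool.index_copy_(0, dest_index, new_keys)
assert torch.equal(key_buffer_pool[10], new_keys[0])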
Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + if HAS_VLLM_KERNERL: + cos_sin_cache = torch.cat((cos, sin), dim=-1) + rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) + key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) + value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + + flash_attention_mask = None + attn_mask_type = AttnMaskType.causal + if attention_mask != None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) + attn_output = attention(query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + return forward + + +def get_llama_vllm_rmsnorm_forward(): + try: + from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True + except: + print("please install vllm kernels to install rmsnorm") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") + HAS_VLLM_KERNERL = False + + if HAS_VLLM_KERNERL: + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None diff --git a/colossalai/tensor_parallel/tpinference/pollcies/__init__.py b/colossalai/tensor_parallel/tpinference/pollcies/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/tensor_parallel/tpinference/pollcies/llama.py b/colossalai/tensor_parallel/tpinference/pollcies/llama.py new file mode 100644 index 000000000000..570e10ba3010 --- 
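get_llama_vllm_rmsnorm_forward above only swaps in the fused vLLM rms_norm kernel when it can be imported. For reference, a plain-PyTorch version of the same computation (normalize by the root mean square over the hidden dimension, then scale by the learned weight); this is a sketch of the math, not the fused kernel:

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the hidden dimension, then a learned scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))


hidden_states = torch.randn(2, 5, 8)
weight = torch.ones(8)
print(rms_norm_reference(hidden_states, weight).shape)  # torch.Size([2, 5, 8])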
/dev/null +++ b/colossalai/tensor_parallel/tpinference/pollcies/llama.py @@ -0,0 +1,35 @@ +from functools import partial + +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling.llama import LlamaInferenceForwards + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + policy = super().module_policy() + self.shard_config._infer() + + # example for replace layer or decoder + # if self.shard_config.enable_flash_attention: + # policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + # 'forward': get_llama_flash_attention_forward(), + # }) + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {'forward': partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) + + return policy \ No newline at end of file diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index dfab2de1f2dc..2e8b2aecd1f2 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -7,7 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from llama_infer_engine import TPCacheManagerInferenceEngine +from colossalai.tensor_parallel.tpinference.llama_infer_engine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 1 From 2a6a380cc77db34d68909c9acbca9f74b909c680 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 28 Aug 2023 10:42:18 +0800 Subject: [PATCH 19/28] fix __init__.py --- .../tensor_parallel/tpinference/__init__.py | 36 ++----------------- .../tpinference/modeling/__init__.py | 3 ++ .../tpinference/pollcies/__init__.py | 3 ++ 3 files changed, 9 insertions(+), 33 deletions(-) diff --git a/colossalai/tensor_parallel/tpinference/__init__.py b/colossalai/tensor_parallel/tpinference/__init__.py index 64533cdde731..8284818c594a 100644 --- a/colossalai/tensor_parallel/tpinference/__init__.py +++ b/colossalai/tensor_parallel/tpinference/__init__.py @@ -1,34 +1,4 @@ -from functools import partial - from .modeling.llama import LlamaInferenceForwards -from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy - -class LlamaModelInferPolicy(LlamaForCausalLMPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel - policy = super().module_policy() - self.shard_config._infer() - - # example for replace layer or decoder - # if self.shard_config.enable_flash_attention: - # policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ - # 'forward': 
get_llama_flash_attention_forward(), - # }) - - infer_forward = LlamaInferenceForwards.llama_model_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) - - infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer) - - infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) - - return policy \ No newline at end of file +from .pollcies.llama import LlamaModelInferPolicy + +__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy'] \ No newline at end of file diff --git a/colossalai/tensor_parallel/tpinference/modeling/__init__.py b/colossalai/tensor_parallel/tpinference/modeling/__init__.py index e69de29bb2d1..1b022f38c470 100644 --- a/colossalai/tensor_parallel/tpinference/modeling/__init__.py +++ b/colossalai/tensor_parallel/tpinference/modeling/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaInferenceForwards + +__all__ = ['LlamaInferenceForwards'] \ No newline at end of file diff --git a/colossalai/tensor_parallel/tpinference/pollcies/__init__.py b/colossalai/tensor_parallel/tpinference/pollcies/__init__.py index e69de29bb2d1..d92a3e84d097 100644 --- a/colossalai/tensor_parallel/tpinference/pollcies/__init__.py +++ b/colossalai/tensor_parallel/tpinference/pollcies/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaModelInferPolicy + +__all__ = ['LlamaModelInferPolicy'] \ No newline at end of file From d10dcf422fc794bc36800cd8e2f537513e1d31ee Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 28 Aug 2023 11:35:51 +0800 Subject: [PATCH 20/28] rm tensor_parallel --- colossalai/tensor_parallel/__init__.py | 0 colossalai/tensor_parallel/tpinference/__init__.py | 4 ---- colossalai/tpinference/__init__.py | 5 +++++ .../{tensor_parallel => }/tpinference/llama_infer_engine.py | 2 +- .../{tensor_parallel => }/tpinference/modeling/__init__.py | 0 .../{tensor_parallel => }/tpinference/modeling/llama.py | 0 .../{tensor_parallel => }/tpinference/pollcies/__init__.py | 0 .../{tensor_parallel => }/tpinference/pollcies/llama.py | 0 tests/test_infer/test_llama_infer.py | 2 +- 9 files changed, 7 insertions(+), 6 deletions(-) delete mode 100644 colossalai/tensor_parallel/__init__.py delete mode 100644 colossalai/tensor_parallel/tpinference/__init__.py create mode 100644 colossalai/tpinference/__init__.py rename colossalai/{tensor_parallel => }/tpinference/llama_infer_engine.py (99%) rename colossalai/{tensor_parallel => }/tpinference/modeling/__init__.py (100%) rename colossalai/{tensor_parallel => }/tpinference/modeling/llama.py (100%) rename colossalai/{tensor_parallel => }/tpinference/pollcies/__init__.py (100%) rename colossalai/{tensor_parallel => }/tpinference/pollcies/llama.py (100%) diff --git a/colossalai/tensor_parallel/__init__.py b/colossalai/tensor_parallel/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/colossalai/tensor_parallel/tpinference/__init__.py b/colossalai/tensor_parallel/tpinference/__init__.py deleted file mode 100644 index 8284818c594a..000000000000 --- a/colossalai/tensor_parallel/tpinference/__init__.py +++ 
/dev/null @@ -1,4 +0,0 @@ -from .modeling.llama import LlamaInferenceForwards -from .pollcies.llama import LlamaModelInferPolicy - -__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy'] \ No newline at end of file diff --git a/colossalai/tpinference/__init__.py b/colossalai/tpinference/__init__.py new file mode 100644 index 000000000000..fa50a1a89f35 --- /dev/null +++ b/colossalai/tpinference/__init__.py @@ -0,0 +1,5 @@ +from .modeling.llama import LlamaInferenceForwards +from .pollcies.llama import LlamaModelInferPolicy +from .llama_infer_engine import TPCacheManagerInferenceEngine + +__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'TPCacheManagerInferenceEngine'] \ No newline at end of file diff --git a/colossalai/tensor_parallel/tpinference/llama_infer_engine.py b/colossalai/tpinference/llama_infer_engine.py similarity index 99% rename from colossalai/tensor_parallel/tpinference/llama_infer_engine.py rename to colossalai/tpinference/llama_infer_engine.py index e2cc61de8e2e..963bf064a23f 100644 --- a/colossalai/tensor_parallel/tpinference/llama_infer_engine.py +++ b/colossalai/tpinference/llama_infer_engine.py @@ -5,7 +5,7 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.inference import MemoryManager from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -from colossalai.tensor_parallel.tpinference.pollcies.llama import LlamaModelInferPolicy +from colossalai.tpinference.pollcies.llama import LlamaModelInferPolicy from transformers import LlamaForCausalLM, LlamaTokenizer import time diff --git a/colossalai/tensor_parallel/tpinference/modeling/__init__.py b/colossalai/tpinference/modeling/__init__.py similarity index 100% rename from colossalai/tensor_parallel/tpinference/modeling/__init__.py rename to colossalai/tpinference/modeling/__init__.py diff --git a/colossalai/tensor_parallel/tpinference/modeling/llama.py b/colossalai/tpinference/modeling/llama.py similarity index 100% rename from colossalai/tensor_parallel/tpinference/modeling/llama.py rename to colossalai/tpinference/modeling/llama.py diff --git a/colossalai/tensor_parallel/tpinference/pollcies/__init__.py b/colossalai/tpinference/pollcies/__init__.py similarity index 100% rename from colossalai/tensor_parallel/tpinference/pollcies/__init__.py rename to colossalai/tpinference/pollcies/__init__.py diff --git a/colossalai/tensor_parallel/tpinference/pollcies/llama.py b/colossalai/tpinference/pollcies/llama.py similarity index 100% rename from colossalai/tensor_parallel/tpinference/pollcies/llama.py rename to colossalai/tpinference/pollcies/llama.py diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 2e8b2aecd1f2..ad0ddb894aa4 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -7,7 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.tensor_parallel.tpinference.llama_infer_engine import TPCacheManagerInferenceEngine +from colossalai.tpinference.llama_infer_engine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 1 From fb2603b0139d5ca1093b2bc311390ba75cb480d8 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 28 Aug 2023 13:05:46 +0800 Subject: [PATCH 21/28] fix: fix bugs in auto_policy.py --- colossalai/shardformer/policies/auto_policy.py 
| 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 43ea1c5ab7f6..f270d0d5efb4 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -134,7 +134,9 @@ class PolicyLocation: _INFER_POLICY_LIST = { # LlaMa "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy") + PolicyLocation(file_name="colossalai.tpinference.pollcies.llama", class_name="LlamaModelInferPolicy"), + "transformers.models.llama.modeling_llama.LlamaForCausalLMPolicy": + PolicyLocation(file_name="colossalai.tpinference.pollcies.llama", class_name="LlamaModelInferPolicy"), } From 92fd955ac89947faf1dfb5e8908dfa4f2f0b28c0 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 28 Aug 2023 13:17:20 +0800 Subject: [PATCH 22/28] fix:rm some unused codes --- colossalai/shardformer/policies/llama.py | 2 +- colossalai/tpinference/modeling/llama.py | 107 +---------------------- 2 files changed, 2 insertions(+), 107 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 9f9f142cae97..5ee95f3be8fa 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -10,7 +10,7 @@ from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy', "LlamaForCausalLMPolicy"] +__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] class LlamaPolicy(Policy): diff --git a/colossalai/tpinference/modeling/llama.py b/colossalai/tpinference/modeling/llama.py index f91e9dc4b770..f6bb86714746 100644 --- a/colossalai/tpinference/modeling/llama.py +++ b/colossalai/tpinference/modeling/llama.py @@ -316,109 +316,4 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # return past_key_value as None return attn_output, None, None -def get_llama_flash_attention_forward(): - - from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention - - try: - from vllm import pos_encoding_ops - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True - except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") - HAS_VLLM_KERNERL = False - - - def forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." 
- - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - if HAS_VLLM_KERNERL: - cos_sin_cache = torch.cat((cos, sin), dim=-1) - rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache) - else: - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) - - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attn_mask_type = AttnMaskType.paddedcausal - - attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention(query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - return forward - - -def get_llama_vllm_rmsnorm_forward(): - try: - from vllm import layernorm_ops - rms_norm = layernorm_ops.rms_norm - HAS_VLLM_KERNERL = True - except: - print("please install vllm kernels to install rmsnorm") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") - HAS_VLLM_KERNERL = False - - if HAS_VLLM_KERNERL: - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - - return out - - return _vllm_rmsnorm_forward - else: - return None + \ No newline at end of file From c7472492d6eed87b6441344d0517506172866d15 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 28 Aug 2023 14:48:32 +0800 Subject: [PATCH 23/28] mv colossalai/tpinference to colossalai/inference/tensor_parallel --- colossalai/inference/__init__.py | 0 .../tensor_parallel}/__init__.py | 0 .../tensor_parallel}/llama_infer_engine.py | 5 +++-- .../tensor_parallel}/modeling/__init__.py | 0 .../tensor_parallel}/modeling/llama.py | 0 .../tensor_parallel}/pollcies/__init__.py | 0 .../tensor_parallel}/pollcies/llama.py | 0 colossalai/shardformer/policies/auto_policy.py | 16 ++++++++++------ tests/test_infer/test_llama_infer.py | 2 +- 9 files 
changed, 14 insertions(+), 9 deletions(-) create mode 100644 colossalai/inference/__init__.py rename colossalai/{tpinference => inference/tensor_parallel}/__init__.py (100%) rename colossalai/{tpinference => inference/tensor_parallel}/llama_infer_engine.py (97%) rename colossalai/{tpinference => inference/tensor_parallel}/modeling/__init__.py (100%) rename colossalai/{tpinference => inference/tensor_parallel}/modeling/llama.py (100%) rename colossalai/{tpinference => inference/tensor_parallel}/pollcies/__init__.py (100%) rename colossalai/{tpinference => inference/tensor_parallel}/pollcies/llama.py (100%) diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/tpinference/__init__.py b/colossalai/inference/tensor_parallel/__init__.py similarity index 100% rename from colossalai/tpinference/__init__.py rename to colossalai/inference/tensor_parallel/__init__.py diff --git a/colossalai/tpinference/llama_infer_engine.py b/colossalai/inference/tensor_parallel/llama_infer_engine.py similarity index 97% rename from colossalai/tpinference/llama_infer_engine.py rename to colossalai/inference/tensor_parallel/llama_infer_engine.py index 963bf064a23f..8ceb6d76bb9d 100644 --- a/colossalai/tpinference/llama_infer_engine.py +++ b/colossalai/inference/tensor_parallel/llama_infer_engine.py @@ -5,7 +5,8 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.inference import MemoryManager from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -from colossalai.tpinference.pollcies.llama import LlamaModelInferPolicy +from colossalai.inference.tensor_parallel.pollcies.llama import LlamaModelInferPolicy +from colossalai.shardformer.policies.auto_policy import get_autopolicy from transformers import LlamaForCausalLM, LlamaTokenizer import time @@ -146,7 +147,7 @@ def build_model(self): if self.bs >= 4: self.use_cache_manager = True self.init_and_insert_cache_manager() - policy = LlamaModelInferPolicy() + policy = get_autopolicy(self.model, inference_only=True) else: self.use_cache_manager = False policy = LlamaForCausalLMPolicy() diff --git a/colossalai/tpinference/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py similarity index 100% rename from colossalai/tpinference/modeling/__init__.py rename to colossalai/inference/tensor_parallel/modeling/__init__.py diff --git a/colossalai/tpinference/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py similarity index 100% rename from colossalai/tpinference/modeling/llama.py rename to colossalai/inference/tensor_parallel/modeling/llama.py diff --git a/colossalai/tpinference/pollcies/__init__.py b/colossalai/inference/tensor_parallel/pollcies/__init__.py similarity index 100% rename from colossalai/tpinference/pollcies/__init__.py rename to colossalai/inference/tensor_parallel/pollcies/__init__.py diff --git a/colossalai/tpinference/pollcies/llama.py b/colossalai/inference/tensor_parallel/pollcies/llama.py similarity index 100% rename from colossalai/tpinference/pollcies/llama.py rename to colossalai/inference/tensor_parallel/pollcies/llama.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index f270d0d5efb4..aa100a0652ef 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -134,17 +134,21 @@ class PolicyLocation: _INFER_POLICY_LIST = { # 
LlaMa "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="colossalai.tpinference.pollcies.llama", class_name="LlamaModelInferPolicy"), - "transformers.models.llama.modeling_llama.LlamaForCausalLMPolicy": - PolicyLocation(file_name="colossalai.tpinference.pollcies.llama", class_name="LlamaModelInferPolicy"), + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), } -def import_policy(policy_location: PolicyLocation) -> Policy: +def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy: """ Dynamically import a Policy class based on the policy location. """ - module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" + + if inference_only: + module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}" + else: + module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" module = importlib.import_module(module_name) return getattr(module, policy_location.class_name) @@ -181,5 +185,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" ) else: - policy = import_policy(policy_location) + policy = import_policy(policy_location, inference_only) return policy() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index ad0ddb894aa4..f9c8d2383f10 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -7,7 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.tpinference.llama_infer_engine import TPCacheManagerInferenceEngine +from colossalai.inference.tensor_parallel.llama_infer_engine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 1 From c27088f28861905c0813c9833f48a81d3e325d59 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 30 Aug 2023 11:03:09 +0800 Subject: [PATCH 24/28] change __init__.py --- colossalai/inference/tensor_parallel/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py index fa50a1a89f35..30a7a58791a4 100644 --- a/colossalai/inference/tensor_parallel/__init__.py +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -1,5 +1,6 @@ from .modeling.llama import LlamaInferenceForwards from .pollcies.llama import LlamaModelInferPolicy -from .llama_infer_engine import TPCacheManagerInferenceEngine +from .engine import TPInferEngine +from .kvcache_manager import MemoryManager -__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'TPCacheManagerInferenceEngine'] \ No newline at end of file +__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine'] \ No newline at end of file From af160406ea294690917d4936cb50b11086e03524 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 30 Aug 2023 11:06:54 +0800 Subject: [PATCH 25/28] save change --- .../tensor_parallel/modeling/llama.py | 124 +++++++++--------- tests/test_infer/test_llama_infer.py | 2 +- 2 files 
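import_policy above resolves a policy class from a string registry at runtime, composing the module path from either the shardformer package or the inference package depending on inference_only. A small standalone sketch of that registry plus dynamic-import pattern; the demo entry points at a standard-library class so it runs without ColossalAI installed:

import importlib
from dataclasses import dataclass


@dataclass
class ToyPolicyLocation:
    module_name: str   # dotted module path to import at runtime
    class_name: str    # class to pull out of that module


def load_class(location: ToyPolicyLocation):
    # importlib resolves the module lazily, so policies are only imported when requested.
    module = importlib.import_module(location.module_name)
    return getattr(module, location.class_name)


# Stand-in registry entry so the demo runs without ColossalAI installed; the real
# registry maps transformers class names to policy modules under either the
# shardformer package or the inference package, chosen by inference_only.
registry = {"demo.OrderedDict": ToyPolicyLocation("collections", "OrderedDict")}
print(load_class(registry["demo.OrderedDict"]))  # <class 'collections.OrderedDict'>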
changed, 64 insertions(+), 62 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index f6bb86714746..cfef5084e8c2 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -5,17 +5,19 @@ from transformers.modeling_outputs import ( BaseModelOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaModel, LlamaRMSNorm, LlamaAttention +from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention from colossalai.shardformer.inference import BatchInferState from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest from colossalai.kernel.triton.context_attention import llama_context_attn_fwd from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd +from typing import List, Optional, Tuple +from transformers.modeling_outputs import BaseModelOutputWithPast class LlamaInferenceForwards: """ This class holds forwards for llama inference. """ - + @staticmethod def llama_model_forward( self: LlamaModel, @@ -29,23 +31,25 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): + batch_size = input_ids.shape[0] # input_ids.shape[0] - infer_info = BatchInferState(batch_size, input_ids.shape[1]) - infer_info.batch_size = batch_size - # NOTE: dummy implementation here for testing, just assume all inputs same length - infer_info.block_loc = self.block_loc - infer_info.start_loc = self.start_loc - infer_info.seq_len = self.seq_len - infer_info.max_len_in_batch = self.max_len_in_batch + # infer_state = BatchInferState(batch_size, input_ids.shape[1]) + # infer_state.batch_size = batch_size + # # NOTE: dummy implementation here for testing, just assume all inputs same length + # infer_state.block_loc = self.block_loc + # infer_state.start_loc = self.start_loc + # infer_state.seq_len = self.seq_len + # infer_state.max_len_in_batch = self.max_len_in_batch - b_seq_len_numpy = infer_info.seq_len.cpu().numpy() + infer_state = self.infer_state + b_seq_len_numpy = infer_state.seq_len.cpu().numpy() position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda() # this equals - infer_info.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) - infer_info.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -64,37 +68,36 @@ def llama_model_forward( if past_key_values is not None: # TODO dummy but work, revise it - past_key_values_length = self.cache_manager.past_key_values_length + past_key_values_length = infer_state.cache_manager.past_key_values_length # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - - infer_info.set_cache_manager(self.cache_manager) # FIXME: differentiate with prefill stage # block_loc require different value-assigning method for two different stage if use_cache and seq_length != 1: # NOTE assuem prefill stage # allocate memory block - 
infer_info.is_context_stage = True # set prefill stage, notify attention layer - infer_info.context_mem_index = infer_info.cache_manager.alloc(infer_info.total_token_num) - infer_info.init_block_loc(infer_info.block_loc, infer_info.seq_len, seq_length, infer_info.context_mem_index) + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index) else: # TODO handle the condition that no contiguous memory presents - alloc_mem = self.cache_manager.alloc_contiguous(batch_size) + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) if alloc_mem is not None: - infer_info.decode_mem_index = alloc_mem[0] - infer_info.decode_mem_start = alloc_mem[1] - infer_info.decode_mem_end = alloc_mem[2] - infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print(f" infer_info.cache_manager.past_key_values_length: {infer_info.cache_manager.past_key_values_length}") - infer_info.decode_is_contiguous = False - alloc_mem = self.cache_manager.alloc(batch_size) - infer_info.decode_mem_index = alloc_mem - # infer_info.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_info.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_info.block_loc[:, seq_length_with_past - 1] = infer_info.decode_mem_index + print(f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}") + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -113,6 +116,7 @@ def llama_model_forward( attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) + attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) @@ -124,7 +128,7 @@ def llama_model_forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None - infer_info.decode_layer_id = 0 + infer_state.decode_layer_id = 0 for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] if past_key_values is not None else None @@ -136,9 +140,9 @@ def llama_model_forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - infer_info=infer_info, + infer_state=infer_state, ) - infer_info.decode_layer_id += 1 + infer_state.decode_layer_id += 1 hidden_states = 
layer_outputs[0] if use_cache: @@ -148,14 +152,13 @@ def llama_model_forward( next_cache = next_decoder_cache if use_cache else None # update indices - self.max_len_in_batch += 1 - self.block_loc[:, self.max_len_in_batch-1] = self.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - self.start_loc = self.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - self.total_token_num += batch_size - self.seq_len += 1 + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -172,7 +175,7 @@ def llama_decoder_layer_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - infer_info: Optional[BatchInferState] = None, + infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -187,7 +190,7 @@ def llama_decoder_layer_forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - infer_info=infer_info, + infer_state=infer_state, ) hidden_states = residual + hidden_states @@ -218,7 +221,7 @@ def llama_flash_attn_kvcache_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - infer_info: Optional[BatchInferState] = None, + infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: assert use_cache is True, "use_cache should be set to True using this llama attention" @@ -235,11 +238,11 @@ def llama_flash_attn_kvcache_forward( value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len) - cos ,sin = infer_info.position_cos, infer_info.position_sin + cos ,sin = infer_state.position_cos, infer_state.position_sin cos_sin_cache = torch.cat((cos, sin), dim=-1) - from col_pos_encoding_ops import rotary_embedding_neox + from vllm.pos_encoding_ops import rotary_embedding_neox rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) @@ -253,24 +256,24 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, return # copy key and value calculated in current step to memory manager - if infer_info.is_context_stage: - _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.context_mem_index, infer_info.cache_manager) + if infer_state.is_context_stage: + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, infer_state.cache_manager) else: - _copy_kv_to_mem_cache(infer_info.decode_layer_id, key_states, value_states, infer_info.decode_mem_index, infer_info.cache_manager) + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, infer_state.cache_manager) # this is worse than destcopy - # 
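The revised index update above advances the packed-batch bookkeeping after each decode step: every sequence gains one token, so sequence i's start offset in the packed layout shifts by i and every recorded length grows by one. A toy illustration with concrete numbers:

import torch

# Packed layout before the step: lengths [3, 5, 2] start at offsets [0, 3, 8].
seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)
start_loc = torch.tensor([0, 3, 8], dtype=torch.int32)
batch_size = seq_len.numel()

# One decode step appends a token to every sequence: sequence i's start shifts by i.
start_loc = start_loc + torch.arange(0, batch_size, dtype=torch.int32)
seq_len = seq_len + 1

print(start_loc.tolist())  # [0, 4, 10], the offsets of the new lengths [4, 6, 3]
print(seq_len.tolist())    # [4, 6, 3]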
torch.Tensor.copy_(infer_info.cache_manager.key_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],key_states) - # torch.Tensor.copy_(infer_info.cache_manager.value_buffer[infer_info.decode_layer_id][infer_info.decode_mem_start:infer_info.decode_mem_end, :, :],value_states) + # torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states) + # torch.Tensor.copy_(infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],value_states) # FIXME might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_info.decode_layer_id == 0: # once per model.forward - infer_info.cache_manager.past_key_values_length += q_len # seq_len + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len query_states = query_states.transpose(1, 2) - if infer_info.is_context_stage: + if infer_state.is_context_stage: # first token generation # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states, @@ -290,10 +293,9 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, key_states.view(calcu_shape1), value_states.view(calcu_shape1), attn_output.view(calcu_shape1), - infer_info.start_loc, - infer_info.seq_len, - infer_info.max_len_in_batch) - + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length) else: # second token and follows # kv = torch.stack((key_states, value_states), dim=2) @@ -301,13 +303,13 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, attn_output = torch.empty_like(query_states) token_attention_fwd(query_states, - infer_info.cache_manager.key_buffer[infer_info.decode_layer_id], - infer_info.cache_manager.value_buffer[infer_info.decode_layer_id], + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, - infer_info.block_loc, - infer_info.start_loc, - infer_info.seq_len, - infer_info.max_len_in_batch) + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length) attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index f9c8d2383f10..83da446ef92f 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -10,7 +10,7 @@ from colossalai.inference.tensor_parallel.llama_infer_engine import TPCacheManagerInferenceEngine os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -TPSIZE = 1 +TPSIZE = 2 @parameterize('test_config', [{ 'tp_size': TPSIZE, From f30f5428a286dee97973f23667c41e8d2cccea73 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 30 Aug 2023 11:45:40 +0800 Subject: [PATCH 26/28] fix engine --- .../inference/tensor_parallel/engine.py | 22 ++--- .../tensor_parallel/modeling/llama.py | 2 +- tests/test_infer/test_llama_infer_1.py | 87 +++++++++++++++++++ 3 files changed, 99 insertions(+), 12 deletions(-) create mode 100644 tests/test_infer/test_llama_infer_1.py diff --git a/colossalai/inference/tensor_parallel/engine.py 
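During decode, token_attention_fwd above reads each sequence's keys and values straight out of the pooled cache, using block_loc to map (batch index, position) to a pool slot. A tiny sketch of that gather for one sequence, with made-up slot assignments (no Triton kernel, just tensor indexing):

import torch

num_slots, num_heads, head_dim = 32, 2, 4
key_pool = torch.randn(num_slots, num_heads, head_dim)  # pooled key cache, one row per slot

# block_loc[b, t] stores which pool slot holds token t of sequence b.
block_loc = torch.zeros(2, 6, dtype=torch.long)
block_loc[0, :4] = torch.tensor([10, 11, 12, 20])        # sequence 0: 4 tokens, non-contiguous slots
block_loc[1, :3] = torch.tensor([5, 6, 7])               # sequence 1: 3 tokens
seq_len = torch.tensor([4, 3])

# Decode-time attention for sequence b gathers exactly its own key rows from the pool.
b = 0
keys_b = key_pool[block_loc[b, :int(seq_len[b])]]
print(keys_b.shape)  # torch.Size([4, 2, 4])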
b/colossalai/inference/tensor_parallel/engine.py index f643d892aab9..6f2009b43749 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -16,7 +16,7 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM'] +_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] class TPInferEngine: @@ -27,7 +27,7 @@ def __init__(self, max_input_len: int, max_output_len: int, dtype: torch.dtype = torch.float16, - device: torch.device = torch.cuda.current_device()) -> None: + device: str = 'cuda') -> None: self.model = model self.sharded_model = None @@ -40,7 +40,7 @@ def __init__(self, assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint" - self.device = device + torch.device(device=device) self.dtype = dtype self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads @@ -88,7 +88,7 @@ def shard_model_by(self, shardformer: ShardFormer) -> None: assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(self.model, inference_only=True) self.sharded_model, _ = shardformer.optimize(self.model, policy) - self.sharded_model = self.sharded_model.to(self.device) + self.sharded_model = self.sharded_model.cuda() @staticmethod def _supported_models() -> List[str]: @@ -137,7 +137,7 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Te input_tokens = dict(input_ids=input_tokens) for t in input_tokens: if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].to(self.device) + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False) @@ -173,8 +173,8 @@ def prepare_batch_state(self, inputs) -> BatchInferState: else: batch_size = inputs.shape[0] - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device) - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device) + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda') + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda') start_index = 0 max_len_in_batch = -1 @@ -197,10 +197,10 @@ def prepare_batch_state(self, inputs) -> BatchInferState: block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, - device=self.device) + device='cuda') batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to(self.device) # might want to assign specific device - batch_infer_state.start_loc = seq_start_indexes.to(self.device) + batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device + batch_infer_state.start_loc = seq_start_indexes.to('cuda') batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 batch_infer_state.past_key_values_len = 0 @@ -251,4 +251,4 @@ def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: # => put information already recorded in batchinferstate and pass it to model forward # => clear records in engine def add_request(): - raise NotImplementedError() + raise NotImplementedError() \ No newline at end of file diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 
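prepare_batch_state above derives exactly this kind of bookkeeping for an incoming batch: per-sequence lengths, the cumulative start offsets of the packed layout, and the longest sequence in the batch. A toy version assuming the token counts are already known:

import torch

# Token counts per request (in the engine these come from the attention mask or input shape).
lengths = [3, 7, 5]

batch_size = len(lengths)
seq_lengths = torch.zeros(batch_size, dtype=torch.int32)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32)

start_index, max_len_in_batch = 0, -1
for i, num_tokens in enumerate(lengths):
    seq_lengths[i] = num_tokens
    seq_start_indexes[i] = start_index   # where this sequence begins in the packed layout
    start_index += num_tokens
    max_len_in_batch = max(max_len_in_batch, num_tokens)

print(seq_start_indexes.tolist(), max_len_in_batch)  # [0, 3, 10] 7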
cfef5084e8c2..df1b99769d3e 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -6,7 +6,7 @@ BaseModelOutputWithPast, ) from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention -from colossalai.shardformer.inference import BatchInferState +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest from colossalai.kernel.triton.context_attention import llama_context_attn_fwd from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd diff --git a/tests/test_infer/test_llama_infer_1.py b/tests/test_infer/test_llama_infer_1.py new file mode 100644 index 000000000000..0f9271376f3e --- /dev/null +++ b/tests/test_infer/test_llama_infer_1.py @@ -0,0 +1,87 @@ +import os + +import pytest +import torch +import numpy as np + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from transformers import LlamaForCausalLM, LlamaTokenizer +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.tensor_parallel.engine import TPInferEngine +import torch.distributed as dist + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 1 + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config,"max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config,"max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_llama_test(test_config): + + llama_model_path = "/data/scratch/llama-7b-hf" + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(model.model, base=10000) + model = model.half() + model.to(torch.cuda.current_device()) + + text = "Introduce some landmarks in Beijing" + input_ids = tokenizer.encode(text, return_tensors='pt') + # pg_mesh = ProcessGroupMesh(1, 1, test_config["tp_size"]) + + infer_engine = TPInferEngine(model.half(), 4, 12, 8) + shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) + + infer_engine.prepare_with_shard_config(shard_config) + infer_engine.shard_model_by(shardformer) + + generate_kwargs = dict(do_sample=False) + outputs = infer_engine.generate(input_ids, generate_kwargs) + + print("outputs: ", 
outputs) + + output_text = tokenizer.decode(outputs) + print(output_text) + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() From 4b52ebd514e7c2b537a8d014e7db2858128cd89c Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 30 Aug 2023 11:58:43 +0800 Subject: [PATCH 27/28] Bug fix: Fix hang --- .../inference/tensor_parallel/engine.py | 2 +- tests/test_infer/test_llama_infer.py | 60 ++++++++++--- tests/test_infer/test_llama_infer_1.py | 87 ------------------- 3 files changed, 51 insertions(+), 98 deletions(-) delete mode 100644 tests/test_infer/test_llama_infer_1.py diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 6f2009b43749..e833ef3bdb7e 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -137,7 +137,7 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Te input_tokens = dict(input_ids=input_tokens) for t in input_tokens: if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + input_tokens[t] = input_tokens[t].cuda() outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False) diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 83da446ef92f..89646ca9f97f 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -7,27 +7,67 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.inference.tensor_parallel.llama_infer_engine import TPCacheManagerInferenceEngine +from transformers import LlamaForCausalLM, LlamaTokenizer +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.tensor_parallel.engine import TPInferEngine +import torch.distributed as dist os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 2 +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config,"max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config,"max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + @parameterize('test_config', [{ 'tp_size': TPSIZE, }]) def run_llama_test(test_config): - 
input_len = 1024 - output_len = 128 - bs = 8 - engine = TPCacheManagerInferenceEngine(input_len, output_len, bs, test_config["tp_size"]) - engine.generate_data() - engine.prepare_model() - engine.build_model() + llama_model_path = "/data/scratch/llama-7b-hf" + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(model.model, base=10000) + model = model.half() + model.to(torch.cuda.current_device()) + + text = "Introduce some landmarks in Beijing" + input_ids = tokenizer.encode(text, return_tensors='pt') + # pg_mesh = ProcessGroupMesh(1, 1, test_config["tp_size"]) - outputs_list = engine.run_infer(test_origin=False) + infer_engine = TPInferEngine(model.half(), 4, 12, 8) + shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) + + infer_engine.prepare_with_shard_config(shard_config) + infer_engine.shard_model_by(shardformer) + + generate_kwargs = dict(do_sample=False) + outputs = infer_engine.generate(input_ids, generate_kwargs) - torch.cuda.empty_cache() + print("outputs: ", outputs) + + output_text = tokenizer.decode(outputs[0]) + print(output_text) def check_llama(rank, world_size, port): diff --git a/tests/test_infer/test_llama_infer_1.py b/tests/test_infer/test_llama_infer_1.py deleted file mode 100644 index 0f9271376f3e..000000000000 --- a/tests/test_infer/test_llama_infer_1.py +++ /dev/null @@ -1,87 +0,0 @@ -import os - -import pytest -import torch -import numpy as np - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from transformers import LlamaForCausalLM, LlamaTokenizer -from colossalai.cluster import ProcessGroupMesh -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.inference.tensor_parallel.engine import TPInferEngine -import torch.distributed as dist - -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -TPSIZE = 1 - -def init_to_get_rotary(self, base=10000): - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - if hasattr(self.config,"max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config,"max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() - return - -@parameterize('test_config', [{ - 'tp_size': TPSIZE, -}]) -def run_llama_test(test_config): - - llama_model_path = "/data/scratch/llama-7b-hf" - tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) - tokenizer.pad_token_id = tokenizer.unk_token_id - model = LlamaForCausalLM.from_pretrained(llama_model_path, 
pad_token_id=tokenizer.eos_token_id) - init_to_get_rotary(model.model, base=10000) - model = model.half() - model.to(torch.cuda.current_device()) - - text = "Introduce some landmarks in Beijing" - input_ids = tokenizer.encode(text, return_tensors='pt') - # pg_mesh = ProcessGroupMesh(1, 1, test_config["tp_size"]) - - infer_engine = TPInferEngine(model.half(), 4, 12, 8) - shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) - shardformer = ShardFormer(shard_config=shard_config) - - infer_engine.prepare_with_shard_config(shard_config) - infer_engine.shard_model_by(shardformer) - - generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) - - print("outputs: ", outputs) - - output_text = tokenizer.decode(outputs) - print(output_text) - - -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_llama_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama(): - spawn(check_llama, TPSIZE) - - -if __name__ == "__main__": - test_llama() From 6d0642129c3ba2d864892fde3678853742262368 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 30 Aug 2023 12:00:21 +0800 Subject: [PATCH 28/28] remove llama_infer_engine.py --- .../tensor_parallel/llama_infer_engine.py | 219 ------------------ 1 file changed, 219 deletions(-) delete mode 100644 colossalai/inference/tensor_parallel/llama_infer_engine.py diff --git a/colossalai/inference/tensor_parallel/llama_infer_engine.py b/colossalai/inference/tensor_parallel/llama_infer_engine.py deleted file mode 100644 index 645cfcdcc38e..000000000000 --- a/colossalai/inference/tensor_parallel/llama_infer_engine.py +++ /dev/null @@ -1,219 +0,0 @@ - -import copy -import torch -from colossalai.cluster import ProcessGroupMesh -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.inference import MemoryManager -from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -from colossalai.shardformer.policies.auto_policy import get_autopolicy -from transformers import LlamaForCausalLM, LlamaTokenizer -import time - -GIGABYTE = 1024 ** 3 -torch.backends.cudnn.enabled = True -DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 - -def print_device_memory(): - if torch.cuda.is_available(): - current_device = torch.cuda.current_device() - print(f"Currently using GPU: {current_device}") - - # free memory and the total available memory in bytes - global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info() - memory_allocated = torch.cuda.memory_allocated() - max_memory_allocated = torch.cuda.max_memory_allocated() - memory_reserved = torch.cuda.memory_reserved() - max_memory_reserved = torch.cuda.max_memory_reserved() - - print( - f" free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" - f" total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n" - f" memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n" - f" Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n" - f" memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n" - f" Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" - ) - -def print_perf_stats(latency_set, config, bs, warmup=3): - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = 
sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 - - print(f"num_layers: {num_layers}, hidden_size: {config.hidden_size}") - - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1/avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1/avg * num_parameters * num_bytes * bs / 1e12)) - print("Avg Throughput: tokens/s: {}".format((1000/(avg * 1000)))) - -def init_to_get_rotary(self, base=10000): - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - if hasattr(self.config,"max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config,"max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() - return - -class TPCacheManagerInferenceEngine: - - def __init__( - self, - input_len: int, - output_len: int, - bs: int, - tp_size: int, - ) -> None: - self.pg_mesh = ProcessGroupMesh(1, 1, tp_size) - self.input_len = input_len - self.output_len = output_len - self.bs = bs - self.tp_size = tp_size - - def init_and_insert_cache_manager(self): - - max_total_token_num = self.bs * (self.input_len + self.output_len) - - head_num = self.model.config.num_attention_heads // self.tp_size - - self.cache_manager = MemoryManager( - max_total_token_num, - torch.float16, - head_num, - self.model.config.hidden_size // self.model.config.num_attention_heads, - self.model.config.num_hidden_layers, - device="cuda", - ) - - setattr(self.model.model, 'cache_manager', self.cache_manager) - - block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") - start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - max_total_token_num = self.bs * (self.input_len) - max_len_in_batch = self.input_len - for i in range(self.bs): - block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") - start_loc[i] = i * self.input_len - seq_len[i] = self.input_len - setattr(self.model.model, 'block_loc', block_loc) - setattr(self.model.model, 'start_loc', start_loc) - setattr(self.model.model, 'seq_len', seq_len) - setattr(self.model.model,'total_token_num', max_total_token_num) - setattr(self.model.model,'max_len_in_batch',max_len_in_batch) - - - def prepare_model(self): - llama_model_path = "/data/scratch/llama-7b-hf" - tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) - tokenizer.pad_token_id = tokenizer.unk_token_id - self.model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) - init_to_get_rotary(self.model.model, base=10000) - 
self.model = self.model.half() - self.model.to(torch.cuda.current_device()) - - - def build_model(self): - # create new model - self.orgin_model = copy.deepcopy(self.model) - shardconfig = ShardConfig( - tensor_parallel_process_group=self.pg_mesh.get_group_along_axis(TP_DIM), - enable_tensor_parallelism=True, - inference_only=True, - ) - shardformer = ShardFormer(shard_config=shardconfig) - - if self.bs >= 4: - self.use_cache_manager = True - self.init_and_insert_cache_manager() - policy = get_autopolicy(self.model, inference_only=True) - else: - self.use_cache_manager = False - policy = LlamaForCausalLMPolicy() - - self.model, _ = shardformer.optimize(self.model, policy) - - def generate_data(self): - self.input_tokens={"input_ids":torch.randint(1, 1000, (self.bs, self.input_len))} - for t in self.input_tokens: - if torch.is_tensor(self.input_tokens[t]): - self.input_tokens[t] = self.input_tokens[t].to(torch.cuda.current_device()) - print(f" input_tokens[{t}].shape: {self.input_tokens[t].shape}") - self.input_len = self.input_tokens[t].shape[1] - print(f" input_len: {self.input_len}") - - def run_infer(self, test_origin=True): - - if test_origin: - model = self.orgin_model - else: - model = self.model - - generate_kwargs = dict(max_new_tokens=self.output_len, do_sample=False) - - iters = 10 - times = [] - outputs_list = [] - warmup = 3 - for i in range(iters): - torch.cuda.synchronize() - start = time.time() - outputs = model.generate(**self.input_tokens, - **generate_kwargs, early_stopping=False) - outputs_list.append(outputs) - torch.cuda.synchronize() - - end = time.time() - num_tokens_generation = outputs.shape[1] - self.input_len - print(f"num_tokens_generation: {num_tokens_generation}") - print(f"generation time is {(end - start) * 1000} ms") - time_spend = (end-start)/num_tokens_generation - times.append(time_spend) - - if not test_origin and self.use_cache_manager: - model.model.cache_manager.free_all() - block_loc = torch.empty(self.bs, self.input_len + self.output_len, dtype=torch.long, device="cuda") - start_loc = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - seq_len = torch.zeros(self.bs, dtype=torch.int32, device="cuda") - max_len_in_batch = self.input_len - for i in range(self.bs): - block_loc[i, 0:self.input_len] = i * self.input_len + torch.arange(0, self.input_len, dtype=torch.int32, device="cuda") - start_loc[i] = i * self.input_len - seq_len[i] = self.input_len - - setattr(model.model, 'block_loc', block_loc) - setattr(model.model, 'start_loc', start_loc) - setattr(model.model, 'seq_len', seq_len) - setattr(model.model,'max_len_in_batch',max_len_in_batch) - - print_device_memory() - total_time = (end - start) * 1000 - token_latency = total_time/(self.output_len) - print(outputs.shape) - print('per_token_latency',token_latency,'ms') - print_perf_stats(times, self.model.config, self.bs, warmup=warmup) - - return outputs_list - - - \ No newline at end of file
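
Taken together, these patches retire the benchmark-oriented TPCacheManagerInferenceEngine and leave TPInferEngine as the tensor-parallel inference entry point. The following is a minimal usage sketch assembled from the test files in this series; it is illustrative only. The checkpoint path and the (4, 12, 8) batch/length values are the placeholders used in the tests, and the sketch assumes a distributed context has already been launched (the tests do this via colossalai.launch inside spawn) and that the rotary cos/sin caches have been pre-computed with the tests' local init_to_get_rotary helper, both omitted here.

# Minimal TPInferEngine usage sketch, mirroring tests/test_infer/test_llama_infer.py.
# Assumes: colossalai.launch(...) has been called and rotary caches are initialized.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig, ShardFormer

model_path = "/data/scratch/llama-7b-hf"  # placeholder checkpoint path from the tests
tokenizer = LlamaTokenizer.from_pretrained(model_path)
tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half().to(torch.cuda.current_device())

# Build the engine (max_batch_size=4, max_input_len=12, max_output_len=8 as in the tests)
# and shard the model for tensor-parallel, inference-only execution.
engine = TPInferEngine(model, 4, 12, 8)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
engine.prepare_with_shard_config(shard_config)
engine.shard_model_by(ShardFormer(shard_config=shard_config))

# Greedy generation driven through the engine's cache-managed forward path.
input_ids = tokenizer.encode("Introduce some landmarks in Beijing", return_tensors="pt")
outputs = engine.generate(input_ids, dict(do_sample=False))
print(tokenizer.decode(outputs[0]))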