diff --git a/colossalai/inference/continous_batching/layers/attention.py b/colossalai/inference/continous_batching/layers/attention.py new file mode 100644 index 000000000000..e3ddd8bdc30f --- /dev/null +++ b/colossalai/inference/continous_batching/layers/attention.py @@ -0,0 +1,440 @@ +"""Multi-head attention.""" +from typing import List, Optional + +import torch +import torch.nn as nn +from vllm import attention_ops, cache_ops, pos_encoding_ops +from vllm.model_executor.input_metadata import InputMetadata +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias + +_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] + + +class PagedAttention(nn.Module): + # pylint: disable=line-too-long + """GPT-style multi-head PagedAttention. + + This class takes flattened 1D query, key, and value tensors as input. The + input 1D tensors can either contain prompt tokens or generation tokens, in + addition to paddings. + + If the input tensors contain prompt tokens, the layout is as follows: + + |<---------------------- num_valid_tokens ---------------------->| + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| + + Otherwise, the layout is as follows: + + |<------------------ num_valid_tokens ------------------->| + |<------- num_generation_tokens (M) ------->| + |<--generation_0-->|...|<--generation_M-1-->|<--padding-->| + + The prompts might have different lengths, while the generation tokens always + have length 1. The paddings are appended to make the input length a multiple + of 8, which is desirable for Tensor Cores. + + The class does the following: + 1. Perform multi_query_kv_attention for the prompts. This operation does + not use the KV cache. + 2. Wait for the cache operations (e.g., swap, copy) to finish. The cache + operations are issued by the cache engine before executing the forward + pass of the model, and they are executed asynchronously. + 3. Reshape and store the input key and value tensors in the KV cache. + 4. Perform single_query_cached_kv_attention for the generation tokens. + This operation reads the previous key and value tensors from the KV + cache. + 5. Output a flattened 1D tensor. + """ + + def __init__(self, num_heads: int, head_size: int, scale: float, num_kv_heads: Optional[int] = None) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.head_mapping = torch.repeat_interleave(torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"), + self.num_queries_per_kv) + + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + + def set_attn_bias( + self, + input_metadata: InputMetadata, + dtype: torch.dtype, + ) -> None: + del dtype # Unused. + if input_metadata.attn_bias: + # Already set by a previous layer. + return + prompt_lens = input_metadata.prompt_lens + attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) + input_metadata.attn_bias.append(attn_bias) + + def multi_query_kv_attention( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Normal attention for the prompt tokens. + + Args: + output: shape = [num_prompt_tokens, num_heads, head_size] + query: shape = [num_prompt_tokens, num_heads, head_size] + key: shape = [num_prompt_tokens, num_kv_heads, head_size] + value: shape = [num_prompt_tokens, num_kv_heads, head_size] + input_metadata: metadata for paged attention. + """ + + if self.num_kv_heads != self.num_heads: + # Project the key and value tensors to the desired number of heads. + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=1) + + # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. + out = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=input_metadata.attn_bias[0], + p=0.0, + scale=self.scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output.copy_(out.squeeze(0)) + return output + + def single_query_cached_kv_attention( + self, + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + ) -> None: + """PagedAttention for the generation tokens. + + Args: + output: shape = [num_generation_tokens, num_heads, head_size] + query: shape = [num_generation_tokens, num_heads, head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for paged attention. + """ + block_size = value_cache.shape[3] + attention_ops.single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + None, # alibi_slopes + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + kv_cache_stream: torch.cuda.Stream = torch.cuda.current_stream(), + ) -> torch.Tensor: + """PagedAttention forward pass. + + NOTE: The query, key, and value tensors must be sliced from a qkv + tensor of shape [num_tokens, 3 * num_heads * head_size]. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for paged attention. + cache_event: event to wait for the cache operations to finish. + kv_cache_stream: stream used to kv cache operation. + + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + # Pre-allocate the output tensor. + output = torch.empty_like(query) + + # Reshape the keys and values and store them in the cache. + # When key_cache and value_cache are not provided, the new key + # and value vectors will not be cached. + #TODO The usage of multiple streams may have issues and may need to be fixed later + kv_cache_event = torch.cuda.Event() + origin_stream = torch.cuda.current_stream() + torch.cuda.stream(kv_cache_stream) + num_valid_tokens = input_metadata.num_valid_tokens + num_prompt_tokens = input_metadata.num_prompt_tokens + if (num_valid_tokens > 0 and key_cache is not None and value_cache is not None): + # The stride is 3 because the key and value are sliced from qkv. + cache_ops.reshape_and_cache( + key[:num_valid_tokens], + value[:num_valid_tokens], + key_cache, + value_cache, + input_metadata.slot_mapping, + ) + kv_cache_event.record(stream=kv_cache_stream) + torch.cuda.stream(origin_stream) + + # Compute the attention op for prompts. + if num_prompt_tokens > 0: + # Prompt run. + assert input_metadata.num_generation_tokens == 0 + self.set_attn_bias(input_metadata, dtype=query.dtype) + self.multi_query_kv_attention( + output[:num_prompt_tokens], + query[:num_prompt_tokens], + key[:num_prompt_tokens], + value[:num_prompt_tokens], + input_metadata, + ) + + # Wait until the cache op is done. + if cache_event is not None: + cache_event.wait() + kv_cache_event.wait() + + if input_metadata.num_generation_tokens > 0: + # Decoding run. + assert input_metadata.num_prompt_tokens == 0 + assert key_cache is not None and value_cache is not None, ( + "key_cache and value_cache must be provided when " + "generating tokens.") + # Compute the attention op for generation tokens. + self.single_query_cached_kv_attention(output[num_prompt_tokens:num_valid_tokens], + query[num_prompt_tokens:num_valid_tokens], key_cache, value_cache, + input_metadata) + + # Reshape the output tensor. + # NOTE(woosuk): The output tensor may include paddings. + return output.view(-1, self.num_heads * self.head_size) + + +class PagedAttentionWithRoPE(PagedAttention): + """PagedAttention with rotary embedding.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + rotary_dim: int, + max_position: int = 8192, + base: int = 10000, + num_kv_heads: Optional[int] = None, + is_neox_style: bool = True, + kv_cache_stream: torch.cuda.Stream = torch.cuda.current_stream(), + ) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads) + self.is_neox_style = is_neox_style + + # Create the cos and sin cache. + inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim)) + t = torch.arange(max_position, dtype=torch.float, device="cuda") + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + self.kv_cache_stream = kv_cache_stream + + # FIXME(woosuk): This assumes that we configure the default dtype when + # initializing the model. + # TODO(woosuk): Make it more robust. + torch_dtype = torch.get_default_dtype() + cache = cache.to(torch_dtype) + # Embedding size: [max_position, rotary_dim] + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + """ PagedAttention forward pass with rotary embedding. + + Args: + positions: shape = [num_tokens] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for paged attention. + cache_event: event to wait for the cache operations to finish. + + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + # Apply rotary embedding to the query and key before passing them + # to the attention op. + pos_encoding_ops.rotary_embedding_neox( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + ) + return super().forward(query, key, value, key_cache, value_cache, input_metadata, cache_event, + self.kv_cache_stream) + + +class PagedAttentionWithALiBi(PagedAttention): + """PagedAttention with ALiBi attention bias.""" + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + slopes: List[float], + num_kv_heads: Optional[int] = None) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads) + assert len(slopes) == num_heads + + slopes = torch.tensor(slopes, dtype=torch.float32) + self.register_buffer("alibi_slopes", slopes, persistent=False) + + def set_attn_bias(self, input_metadata: InputMetadata, dtype: torch.dtype) -> None: + if input_metadata.attn_bias: + # Already set by a previous layer. + return + # Generates ALiBi mask for each prompt. + for prompt_len in input_metadata.prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # Note(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + bias = bias.to(self.alibi_slopes.device) + + # When using custom attention bias, xformers requires the bias to + # be sliced from a tensor whose length is a multiple of 8. + padded_len = (prompt_len + 7) // 8 * 8 + bias = torch.empty( + 1, # batch_size + self.num_heads, + prompt_len, + padded_len, + device=self.alibi_slopes.device, + dtype=dtype, + )[:, :, :, :prompt_len].copy_(bias) + bias.mul_(self.alibi_slopes[:, None, None]) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + input_metadata.attn_bias.append(attn_bias) + + def multi_query_kv_attention( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Attention with ALiBi bias for the prompt tokens. + + Args: + output: shape = [num_prompt_tokens, num_heads, head_size] + query: shape = [num_prompt_tokens, num_heads, head_size] + key: shape = [num_prompt_tokens, num_kv_heads, head_size] + value: shape = [num_prompt_tokens, num_kv_heads, head_size] + input_metadata: metadata for paged attention. + """ + if self.num_kv_heads != self.num_heads: + # Project the key and value tensors to the desired number of heads. + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=1) + + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + start = 0 + for i, prompt_len in enumerate(input_metadata.prompt_lens): + end = start + prompt_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=input_metadata.attn_bias[i], + p=0.0, + scale=self.scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.squeeze(0)) + start += prompt_len + return output + + def single_query_cached_kv_attention( + self, + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + ) -> None: + """PagedAttention with ALiBi bias for the generation tokens. + + Args: + output: shape = [num_generation_tokens, num_heads, head_size] + query: shape = [num_generation_tokens, num_heads, head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for paged attention. + """ + block_size = value_cache.shape[3] + attention_ops.single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + self.alibi_slopes, + ) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index a5a55702ade0..2fb09f190282 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -2,31 +2,61 @@ import torch import torch.nn as nn -from transformers import BloomForCausalLM, LlamaForCausalLM +import warnings +from transformers import ( + AutoConfig, + AutoTokenizer, + BloomForCausalLM, + LlamaForCausalLM, + LlamaTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) from transformers.generation import GenerationConfig from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.tokenization_utils_base import BatchEncoding +try: + from vllm import LLM + from vllm.outputs import RequestOutput + from vllm.sampling_params import SamplingParams + + USE_CONTINOUS_BATCHING = True + +except ImportError: + warnings.warn("vllm is not installed, continuous batching will not be supported.") + USE_CONTINOUS_BATCHING = False + RequestOutput = None from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.auto_policy import get_autopolicy from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager +from .utils import init_to_get_rotary, replace_page_attention DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] +_supported_models = { + "LlamaForCausalLM": LlamaForCausalLM, + "LLaMAForCausalLM": LlamaForCausalLM, + "LlamaModel": LlamaForCausalLM, + "BloomForCausalLM": BloomForCausalLM +} class TPInferEngine: """Engine class for tensor parallel inference. Args: - model (Module): original model, e.g. huggingface CausalLM + model (str, nn.Module): The name, path or instance of a HuggingFace Transformers model. shard_config (ShardConfig): The config for sharding original model max_batch_size (int): maximum batch size max_input_len (int): maximum input length of sequence max_output_len (int): maximum output length of output tokens + trust_remote_code (bool): Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tokenizer (str, PreTrainedTokenizer, PreTrainedTokenizerFast): The name, path or instance of a HuggingFace Transformers tokenizer. + use_continous_batching (bool): whether to use continous_batching dtype (torch.dtype): datatype used to init KV cache space device (str): device the KV cache of engine to be initialized on @@ -40,13 +70,48 @@ class TPInferEngine: """ def __init__(self, - model: nn.Module, - shard_config: ShardConfig, - max_batch_size: int, - max_input_len: int, - max_output_len: int, + model: Union[str, nn.Module], + shard_config: ShardConfig = None, + max_batch_size: int = 8, + max_input_len: int = 16, + max_output_len: int = 8, + trust_remote_code: bool = False, + use_continous_batching: bool = False, + tokenizer: Optional[Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast]] = None, dtype: torch.dtype = torch.float16, device: str = 'cuda') -> None: + + if tokenizer is None: + print("model: ", model) + assert isinstance(model, str), \ + "when tokenizer is None, model must be string." + tokenizer = model + + self.tp_size = 1 + self.use_continous_batching = use_continous_batching and USE_CONTINOUS_BATCHING + + if shard_config != None and shard_config.enable_tensor_parallelism: + self.tp_size = shard_config.tensor_parallel_size + + if self.use_continous_batching: + assert isinstance(model, str) and isinstance(tokenizer, str), \ + "when using continous_batching, model and tokenizer must be string." + self.llm_engine = LLM(model=model, + tokenizer=tokenizer, + trust_remote_code=trust_remote_code, + tensor_parallel_size=self.tp_size) + #TODO We will replace multiple models' attention forward with shardformer in vllm to achieve multi-stream optimization later. + # kv_cache_stream = torch.cuda.Stream() + # self.model = replace_page_attention(self.llm_engine.llm_engine.workers[0].model, kv_cache_stream) + self.model = self.llm_engine.llm_engine.workers[0].model + else: + self.model, self.tokenizer = self._get_model_and_tokenizer(model, tokenizer, trust_remote_code) + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.model = self.model.half() + self.model = self.model.to(device) + self.shard_config = shard_config + self.max_batch_size = max_batch_size self.max_input_len = max_input_len self.max_output_len = max_output_len @@ -59,17 +124,47 @@ def __init__(self, self.dtype = dtype - self.head_dim = model.config.hidden_size // model.config.num_attention_heads - self.head_num = model.config.num_attention_heads - self.layer_num = model.config.num_hidden_layers + self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + self.head_num = self.model.config.num_attention_heads + self.layer_num = self.model.config.num_hidden_layers - self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None - - self.shard_config = shard_config - self.model = None - # optimize the original model by sharding with ShardFormer - self._optimize_model(model=model.to(device)) + + self._optimize_model() + + def _get_model_and_tokenizer(self, model: str, tokenizer: str, trust_remote_code: bool) -> nn.Module: + + if isinstance(model, nn.Module) and isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + return model, tokenizer + + supported_model = model + + try: + config = AutoConfig.from_pretrained(model, trust_remote_code=trust_remote_code) + except ValueError as e: + if (not trust_remote_code and "requires you to execute the configuration file" in str(e)): + err_msg = ("Failed to load the model config. If the model is a custom " + "model not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _supported_models: + if isinstance(tokenizer, str): + if arch == "LlamaForCausalLM" or arch == "LLaMAForCausalLM" or arch == "LlamaModel": + tokenizer = LlamaTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code) + else: + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code) + if isinstance(model, str): + supported_model = _supported_models[arch].from_pretrained(model, + pad_token_id=tokenizer.eos_token_id) + + return supported_model, tokenizer + raise ValueError(f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {self.supported_models}") def _init_manager(self) -> None: assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" @@ -78,16 +173,17 @@ def _init_manager(self) -> None: self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num) - def _optimize_model(self, model: nn.Module) -> None: + def _optimize_model(self) -> None: """ Optimize the original model by sharding with ShardFormer. In further generation, use the sharded model instead of original model. """ # NOTE we will change to use an inference config later with additional attrs we want - assert self.shard_config.inference_only is True - shardformer = ShardFormer(shard_config=self.shard_config) - self._prepare_with_shard_config(shard_config=self.shard_config) - self._shard_model_by(shardformer, model) + if not self.use_continous_batching: + assert self.shard_config.inference_only is True + shardformer = ShardFormer(shard_config=self.shard_config) + self._prepare_with_shard_config(shard_config=self.shard_config) + self._shard_model_by(shardformer) def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: """ Prepare the engine with a given ShardConfig. @@ -96,7 +192,6 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) shard_config (ShardConfig): shard config given to specify settings of the engine. If not provided, a default ShardConfig with tp size 1 will be created. """ - self.tp_size = 1 if shard_config is None: shard_config = ShardConfig( tensor_parallel_process_group=None, @@ -111,37 +206,76 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) else: shard_config.inference_only = True shard_config.pipeline_stage_manager = None - if shard_config.enable_tensor_parallelism: - self.tp_size = shard_config.tensor_parallel_size self._init_manager() return shard_config - def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: + def _shard_model_by(self, shardformer: ShardFormer) -> None: """ Shard original model by the given ShardFormer and store the sharded model. """ assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" - model_name = model.__class__.__name__ + model_name = self.model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." - policy = get_autopolicy(model, inference_only=True) - self.model, _ = shardformer.optimize(model, policy) + if model_name == "LlamaForCausalLM": + init_to_get_rotary(self.model.model, base=10000) + policy = get_autopolicy(self.model, inference_only=True) + self.model, _ = shardformer.optimize(self.model, policy) self.model = self.model.cuda() @property def supported_models(self) -> List[str]: - return _supported_models + return list(_supported_models.keys()) - def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor: + def generate(self, + prompts: Optional[Union[str, List[str]]] = None, + prompt_token_ids: Optional[Union[BatchEncoding, dict, list, torch.Tensor]] = None, + **generate_kwargs) -> Union[List[RequestOutput], torch.Tensor]: """Generate token sequence. Args: - input_tokens: could be one of the following types + prompts: A list of prompts to generate completions for. + prompt_token_ids: could be one of the following types 1. BatchEncoding or dict (e.g. tokenizer batch_encode) 2. list of input token ids (e.g. appended result of tokenizer encode) 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + If None, we use the tokenizer to convert the prompts to token IDs. Returns: - torch.Tensor: The returned sequence is given inputs + generated_tokens. + Union[List[RequestOutput], torch.Tensor]: The returned sequence is given inputs + generated_tokens. """ + + if prompts is None and prompt_token_ids is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + if prompts is not None and prompt_token_ids is not None: + if isinstance(prompt_token_ids, (BatchEncoding, dict)): + prompt_token_len = len(prompt_token_ids['input_ids']) + elif isinstance(prompt_token_ids, torch.Tensor): + prompt_token_len = prompt_token_ids.shape[0] + else: + prompt_token_len = len(prompt_token_ids) + if len(prompts) != prompt_token_len: + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + + if self.use_continous_batching: + if not isinstance(prompt_token_ids, list): + raise TypeError(f"prompt_token_ids type must be list, when using continous batching.") + sampling_params = SamplingParams(temperature=0.0, max_tokens=self.max_output_len) + return self.llm_engine.generate(prompts=prompts, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) + + if prompt_token_ids is None: + input_tokens = self.tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding=True) + else: + input_tokens = prompt_token_ids + + if isinstance(input_tokens, list): + input_tokens = torch.Tensor(input_tokens).cuda() + if isinstance(input_tokens, torch.Tensor): input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool)) for t in input_tokens: diff --git a/colossalai/inference/tensor_parallel/utils.py b/colossalai/inference/tensor_parallel/utils.py new file mode 100644 index 000000000000..e8a54928cb74 --- /dev/null +++ b/colossalai/inference/tensor_parallel/utils.py @@ -0,0 +1,52 @@ +import types +import warnings + +import torch +try: + from vllm.model_executor.models.llama import LlamaAttention + VLLM_INSTALLED = True +except ImportError: + warnings.warn("vllm is not installed, PageAttention will not be replaced.") + VLLM_INSTALLED = False + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def replace_page_attention(model, kv_cache_stream): + if VLLM_INSTALLED: + + from colossalai.inference.continous_batching.layers.attention import PagedAttentionWithRoPE + + layers = model.model.layers + for i in range(len(layers)): + layer = layers[i] + if isinstance(layer.self_attn, LlamaAttention) is True: + attn = PagedAttentionWithRoPE(layer.self_attn.num_heads, + layer.self_attn.head_dim, + layer.self_attn.scaling, + rotary_dim=layer.self_attn.head_dim, + num_kv_heads=layer.self_attn.num_kv_heads, + kv_cache_stream=kv_cache_stream) + setattr(layer.self_attn, 'attn', attn) + + return model diff --git a/examples/inference/bench_llama_continous_batching.py b/examples/inference/bench_llama_continous_batching.py new file mode 100644 index 000000000000..a4fcb405ff05 --- /dev/null +++ b/examples/inference/bench_llama_continous_batching.py @@ -0,0 +1,155 @@ +import argparse +import os +import time + +import numpy as np +import torch +from torch.profiler import ProfilerActivity, profile, record_function +from transformers import LlamaForCausalLM, LlamaTokenizer +from vllm import LLM +from vllm.sampling_params import SamplingParams + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.inference.tensor_parallel.utils import init_to_get_rotary +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + + +def run_llama_test(args): + llama_model_path = args.path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + tokenizer = args.tokenizer + test_mode = args.test_mode + test_continous_batching = args.test_continous_batching + + if (tokenizer == None): + tokenizer = llama_model_path + + tmp_tokenizer = LlamaTokenizer.from_pretrained(tokenizer) + tmp_tokenizer.pad_token_id = tmp_tokenizer.unk_token_id + tmp_model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tmp_tokenizer.eos_token_id) + + model_config = tmp_model.config + + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + if test_mode == "colossalai" and not test_continous_batching: + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + else: + input_tokens = [np.random.randint(1, 1000, [max_input_len]).tolist() for _ in range(max_batch_size)] + + iters = 10 + times = [] + + if test_mode == "colossalai": + if test_continous_batching: + print("test_continous_batching: ", test_continous_batching) + infer_engine = TPInferEngine( + model=llama_model_path, + max_output_len=max_output_len, + use_continous_batching=test_continous_batching, + tokenizer=tokenizer, + ) + else: + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True) + infer_engine = TPInferEngine( + model=llama_model_path, + shard_config=shard_config, + max_batch_size=max_batch_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + use_continous_batching=test_continous_batching, + tokenizer=tokenizer, + ) + elif test_mode == "vllm": + infer_engine = LLM(model=llama_model_path, tokenizer=tokenizer) + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + if test_mode == "colossalai": + outputs = infer_engine.generate(prompt_token_ids=input_tokens, **generate_kwargs) + elif test_mode == "vllm": + sampling_params = SamplingParams(temperature=0.0, max_tokens=max_output_len) + outputs = infer_engine.generate(prompt_token_ids=input_tokens, sampling_params=sampling_params) + torch.cuda.synchronize() + end = time.time() + if test_mode == "colossalai" and not test_continous_batching: + out_len = outputs.shape[1] + else: + # out_len = len(outputs[0].outputs[0].token_ids) + out_len = 1024 + 128 + print("generation time {} s".format(str(end - start))) + times.append((end - start) / (out_len - max_input_len)) + + print("outputs, ", len(outputs)) + print_perf_stats(times, model_config, max_batch_size) + + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + # with record_function("model_inference"): + # torch.cuda.synchronize() + # if test_mode == "colossalai": + # outputs = infer_engine.generate(prompt_token_ids=input_tokens, **generate_kwargs) + # elif test_mode == "vllm": + # sampling_params = SamplingParams(temperature=0.0, max_tokens=max_output_len) + # outputs = infer_engine.generate(prompt_token_ids=input_tokens, sampling_params=sampling_params) + # torch.cuda.synchronize() + # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def check_llama(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-t', '--tokenizer', type=str, default=None, help='Tokenizer path') + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument('--test_mode', type=str, default="colossalai", help='test colossalai or vllm') + parser.add_argument('--test_continous_batching', + type=bool, + default=False, + help='whether to test continous_batching') + + args = parser.parse_args() + + test_llama(args) diff --git a/examples/inference/test_continous_batching.py b/examples/inference/test_continous_batching.py new file mode 100644 index 000000000000..b4f22edd50f7 --- /dev/null +++ b/examples/inference/test_continous_batching.py @@ -0,0 +1,62 @@ +import os + +import pytest +import torch +from packaging import version +from vllm import LLM + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 1 +MAX_OUTPUT_LEN = 100 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') + + +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_llama_test(test_config): + model = 'decapoda-research/llama-7b-hf' + tokenizer = 'hf-internal-testing/llama-tokenizer' + + test_prompts = [ + "A robot may not injure a human being", + "To be or not to be,", + "What is the meaning of life?", + "It is only with the heart that one can see rightly", + "Can you introduce Beijing?", + ] + + infer_engine = TPInferEngine( + model=model, + max_output_len=MAX_OUTPUT_LEN, + use_continous_batching=True, + tokenizer=tokenizer, + ) + generate_kwargs = dict(do_sample=False) + outputs = infer_engine.generate(test_prompts, **generate_kwargs) + + print("outputs: ", outputs) + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index 8ecabf69ecf3..ade23d260b07 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -3,6 +3,7 @@ import pytest import torch from packaging import version +from transformers import PreTrainedTokenizer import colossalai from colossalai.inference.tensor_parallel import TPInferEngine @@ -19,6 +20,10 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +class DummyTokenizer(PreTrainedTokenizer): + pass + + @parameterize('test_config', [{ 'tp_size': TP_SIZE, }]) @@ -27,15 +32,21 @@ def run(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm') for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): orig_model = model_fn() + tokenizer = DummyTokenizer() orig_model = orig_model.half() data = data_gen_fn() shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, inference_only=True) - infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine = TPInferEngine(orig_model, + shard_config, + MAX_BATCH_SIZE, + MAX_INPUT_LEN, + MAX_OUTPUT_LEN, + tokenizer=tokenizer) generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(data, **generate_kwargs) + outputs = infer_engine.generate(prompt_token_ids=data, **generate_kwargs) assert outputs is not None diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index cc3cdd2b501b..aa0ff38f8f89 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from packaging import version -from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizer from transformers.tokenization_utils_base import BatchEncoding import colossalai @@ -22,19 +22,29 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +class DummyTokenizer(PreTrainedTokenizer): + pass + + @parameterize('test_config', [{ 'tp_size': TP_SIZE, }]) def run(test_config): model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) model = BloomForCausalLM(model_config) + tokenizer = DummyTokenizer() model = model.half() model.to(torch.cuda.current_device()) # 1. check TPInferEngine init and model optimization shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, inference_only=True) - infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine = TPInferEngine(model, + shard_config, + MAX_BATCH_SIZE, + MAX_INPUT_LEN, + MAX_OUTPUT_LEN, + tokenizer=tokenizer) assert infer_engine.cache_manager is not None assert infer_engine.tp_size == TP_SIZE @@ -71,7 +81,7 @@ def run(test_config): # 3. check optimized model generate input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) generate_kwargs = dict(do_sample=False) - infer_engine.generate(input_ids, **generate_kwargs) + infer_engine.generate(prompt_token_ids=input_ids, **generate_kwargs) torch.cuda.empty_cache() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index aa8874ea4cb0..66dfb45f20c0 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -4,6 +4,7 @@ import pytest import torch from packaging import version +from transformers import PreTrainedTokenizer import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine @@ -44,6 +45,10 @@ def init_to_get_rotary(self, base=10000): return +class DummyTokenizer(PreTrainedTokenizer): + pass + + @parameterize('test_config', [{ 'tp_size': TPSIZE, }]) @@ -52,16 +57,22 @@ def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm') for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): orig_model = model_fn() + tokenizer = DummyTokenizer() init_to_get_rotary(orig_model.model, base=10000) orig_model = orig_model.half() data = data_gen_fn() shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, inference_only=True) - infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine = TPInferEngine(orig_model, + shard_config, + BATCH_SIZE, + MAX_INPUT_LEN, + MAX_OUTPUT_LEN, + tokenizer=tokenizer) generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(data, **generate_kwargs) + outputs = infer_engine.generate(prompt_token_ids=data, **generate_kwargs) assert outputs is not None