2 changes: 1 addition & 1 deletion colossalai/inference/README.md
@@ -236,7 +236,7 @@ Completion api is used for single sequence request, like answer a question or co
- POST '/chat':
Chat api is used for conversation-style request, which often includes dialogue participants(i.e. roles) and corresponding words. Considering the input data are very different from normal inputs, we introduce Chat-Template to match the data format in chat models.
#### chat-template
Followed `transformers`, we add the chat-template argument. As chat models have been trained with very different formats for converting conversations into a single tokenizable string. Using a format that matches the training data is extremely important. This attribute(chat_template) is inclueded in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace-blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example temlate bellow. Both str or file style chat template are supported.
Following `transformers`, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, and using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers, containing a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the [HuggingFace blog](https://huggingface.co/blog/chat-templates) for more information. We also provide a simple example template below. Both string and file style chat templates are supported.
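For orientation, a minimal sketch of how such a template can be supplied as a plain string through `transformers` (the template and model name below are illustrative placeholders, not the repo's bundled example):

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; any chat-capable tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# A minimal Jinja chat template: prefix each turn with its role.
chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant:"
)

messages = [{"role": "user", "content": "What is speculative decoding?"}]

# With tokenize=False, apply_chat_template returns the formatted prompt string.
prompt = tokenizer.apply_chat_template(messages, chat_template=chat_template, tokenize=False)
print(prompt)
```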
### Usage
#### Args for customizing your server
The configuration for api server contains both serving interface and engine backend.
33 changes: 32 additions & 1 deletion colossalai/inference/config.py
@@ -10,6 +10,7 @@
from transformers.generation import GenerationConfig

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.utils import can_use_flash_attn2

GibiByte = 1024**3

@@ -169,7 +170,8 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
use_spec_dec (bool): Indicate whether to use speculative decoding, defaults to False.
max_n_spec_tokens (int): The maximum number of speculative tokens, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
@@ -214,6 +216,7 @@ class InferenceConfig(RPC_PARAM):
ignore_eos: bool = False

# speculative decoding configs
use_spec_dec: bool = False
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False

@@ -311,6 +314,16 @@ def to_generation_config(self, model_config) -> GenerationConfig:

return GenerationConfig.from_dict(meta_config)

def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig":
use_flash_attn = can_use_flash_attn2(self.dtype)
model_inference_config = ModelShardInferenceConfig(
dtype=self.dtype,
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
)
return model_inference_config

def to_rpc_param(self) -> dict:
kwargs = {
"dtype": str(self.dtype).split(".")[-1],
@@ -362,3 +375,21 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Set the attributes from the parsed arguments.
inference_config = cls(**inference_config_args)
return inference_config


@dataclass
class ModelShardInferenceConfig:
"""
Configuration used when initializing modules for inference modeling.

Args:
dtype (torch.dtype): The data type for weights and activations.
use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
use_spec_dec (bool): Indicate whether to use speculative decoding.
use_flash_attn (bool): Indicate whether to use flash attention.
"""

dtype: torch.dtype = None
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
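As a quick orientation, a hedged sketch of how the new flag and helper above might be exercised (field values are placeholders; only attributes visible in this diff are assumed to exist):

```python
import torch

from colossalai.inference.config import InferenceConfig

# Enable speculative decoding and CUDA kernels in the top-level config.
inference_config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,
    use_spec_dec=True,
    max_n_spec_tokens=5,
)

# Derive the per-shard modeling config; use_flash_attn is resolved internally
# through can_use_flash_attn2(dtype).
shard_infer_config = inference_config.to_model_shard_inference_config()
print(shard_infer_config.use_spec_dec, shard_infer_config.use_flash_attn)
```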
38 changes: 26 additions & 12 deletions colossalai/inference/core/engine.py
@@ -18,7 +18,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
@@ -72,8 +72,9 @@ def __init__(

self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

self.init_model(model_or_path, model_policy)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
@@ -97,21 +98,29 @@ def __init__(
self.capture_model(self.k_cache, self.v_cache)

# Model and related attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.use_spec_dec = self.inference_config.use_spec_dec

self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens

self._verify_args()

def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model and/or load weights

Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model
model_policy (Policy): the policy to replace the model.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for module initialization during inference.
"""

if isinstance(model_or_path, str):
@@ -124,6 +133,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P
# the model load process in the future.
model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
else:
# TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")

except Exception as e:
@@ -167,6 +177,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
@@ -187,7 +198,7 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[P
# assert if_has_index_file, "the model path is invalid"
# cpt_io.load_model(self.model, model_index_file)

free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
@@ -287,6 +298,7 @@ def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
@@ -312,6 +324,7 @@ def _shardformer(
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
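For context, the `extra_kwargs` entry added above is presumably how inference model policies receive the new config during sharding; a hedged sketch of the consuming side (the dict access pattern is assumed, not shown in this diff):

```python
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.shardformer import ShardConfig


def get_model_shard_infer_config(shard_config: ShardConfig) -> ModelShardInferenceConfig:
    # Mirrors the key used in InferenceEngine._shardformer above; a policy's
    # module_policy()/postprocess hook could read it the same way.
    return shard_config.extra_kwargs["model_shard_infer_config"]
```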
@@ -348,6 +361,7 @@ def enable_spec_dec(
engine.clear_spec_dec()
```
"""

if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
@@ -517,19 +531,19 @@ def generate(
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> List[str]:
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.

Args:
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
request_ids (List[int], optional): The request ID. Defaults to None.
return_token_ids (bool): Whether to return output token ids. Defaults to False.
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
prompts (List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.

Returns:
List[str]: Inference result returned by one generation.
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""

gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
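A hedged fragment showing how the two documented return shapes might be consumed (assumes `engine` is an `InferenceEngine` already constructed as in `__init__` above; prompts are placeholders):

```python
# Default: a list of decoded strings, one per prompt.
outputs = engine.generate(prompts=["What is speculative decoding?"])

# With return_token_ids=True the engine additionally returns the generated
# token ids, matching the Tuple[List[str], List[List[int]]] annotation above.
texts, token_ids = engine.generate(
    prompts=["What is speculative decoding?"], return_token_ids=True
)
```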
168 changes: 168 additions & 0 deletions colossalai/inference/modeling/backends/attention_backend.py
@@ -0,0 +1,168 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
from flash_attn import flash_attn_varlen_func

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention


@dataclass
class AttentionMetaData:
query_states: torch.Tensor
key_states: torch.Tensor
value_states: torch.Tensor
k_cache: torch.Tensor
v_cache: torch.Tensor
block_tables: torch.Tensor
block_size: int
kv_seq_len: int = None
sequence_lengths: torch.Tensor = None
cu_seqlens: torch.Tensor = None
sm_scale: int = None
alibi_slopes: torch.Tensor = None
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False


class AttentionBackend(ABC):
@abstractmethod
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
raise NotImplementedError

@abstractmethod
def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
raise NotImplementedError


class CudaAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True. It uses the flash-attn `flash_attn_varlen_func` for prefilling when flash-attn is available,
and otherwise falls back to the Triton op `context_attention_unpadded`; decoding always uses our CUDA op `flash_decoding_attention`.
"""

def __init__(self, use_flash_attn: bool):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
self.use_flash_attn = use_flash_attn

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if self.use_flash_attn:
token_nums = kwargs.get("token_nums", -1)
attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
cu_seqlens_q=attn_metadata.cu_seqlens,
cu_seqlens_k=attn_metadata.cu_seqlens,
max_seqlen_q=attn_metadata.kv_seq_len,
max_seqlen_k=attn_metadata.kv_seq_len,
dropout_p=0.0,
softmax_scale=attn_metadata.sm_scale,
causal=True,
alibi_slopes=attn_metadata.alibi_slopes,
)
attn_output = attn_output.view(token_nums, -1)
else:
attn_output = context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True, # use new k-cache layout
)
return attn_output

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.block_size,
attn_metadata.kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
attn_metadata.alibi_slopes,
attn_metadata.sm_scale,
)
return output_tensor
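For orientation, `cu_seqlens` above is the cumulative-sequence-length tensor that `flash_attn_varlen_func` expects for packed (unpadded) batches; a hedged sketch of how it could be derived from per-sequence lengths (variable names are illustrative, not taken from this PR):

```python
import torch
import torch.nn.functional as F

# Hypothetical per-sequence prompt lengths in a packed batch.
sequence_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)

# flash-attn expects cumulative lengths with a leading zero: [0, 5, 8, 15].
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

# max_seqlen_q / max_seqlen_k (kv_seq_len in AttentionMetaData) is then
# the longest sequence in the batch.
kv_seq_len = int(sequence_lengths.max())
```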


class TritonAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
"""

def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
)

def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
return flash_decoding_attention(
q=attn_metadata.query_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
kv_seq_len=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
max_seq_len_in_batch=attn_metadata.kv_seq_len,
output=attn_metadata.output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=attn_metadata.alibi_slopes,
sm_scale=attn_metadata.sm_scale,
kv_group_num=kwargs.get("num_key_value_groups", 1),
q_len=kwargs.get("q_len", 1),
)


def get_attention_backend(
model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
"""
Get the attention backend based on the inference configurations. The modeling uses the CUDA-kernel-based backend
for attention calculation only when:
1. CUDA kernels are enabled (use_cuda_kernel=True)
2. flash attention can be used (flash-attn installed and dtype is fp16 or bf16)
3. speculative decoding is not used (the CUDA kernels currently do not support speculative decoding)
Otherwise, the Triton attention backend is used. If flash-attn is not installed while `use_cuda_kernel` is True,
the CUDA backend falls back to the Triton prefill kernel with the new k-cache layout.
"""
# Currently only triton kernels support speculative decoding
if model_shard_infer_config.use_spec_dec:
return TritonAttentionBackend()

if model_shard_infer_config.use_cuda_kernel:
return CudaAttentionBackend(model_shard_infer_config.use_flash_attn)

return TritonAttentionBackend()
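To tie the selection logic together, a hedged sketch of picking a backend from a `ModelShardInferenceConfig` (runs only where this module's flash-attn/Triton dependencies are importable; the commented prefill/decode calls assume an `AttentionMetaData` built by the model layer):

```python
import torch

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import (
    TritonAttentionBackend,
    get_attention_backend,
)

# Speculative decoding currently forces the Triton backend, regardless of
# use_cuda_kernel / use_flash_attn.
spec_dec_config = ModelShardInferenceConfig(
    dtype=torch.float16, use_cuda_kernel=True, use_spec_dec=True, use_flash_attn=True
)
assert isinstance(get_attention_backend(spec_dec_config), TritonAttentionBackend)

# Without speculative decoding, use_cuda_kernel=True selects the CUDA backend,
# which itself falls back to the Triton prefill kernel when flash-attn is absent.
cuda_config = ModelShardInferenceConfig(
    dtype=torch.float16, use_cuda_kernel=True, use_spec_dec=False, use_flash_attn=False
)
backend = get_attention_backend(cuda_config)
# backend.prefill(attn_metadata, token_nums=...) and
# backend.decode(attn_metadata, fd_inter_tensor=...) are then called per stage.
```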