1 change: 1 addition & 0 deletions docs/CN/source/index.rst
@@ -49,6 +49,7 @@ Lightllm integrates the strengths of many open-source projects, including but not limited to FasterTran
   :caption: Deployment Tutorials

   DeepSeek R1 Deployment <tutorial/deepseek_deployment>
   FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
   Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
   Multimodal Deployment <tutorial/multimodal>
   Reward Model Deployment <tutorial/reward_model>
5 changes: 4 additions & 1 deletion docs/CN/source/tutorial/api_server_args.rst
@@ -337,7 +337,10 @@ PD Disaggregation Mode Parameters

.. option:: --llm_kv_type

   The data type the inference backend uses to store the KV cache. Valid values: "None", "int8kv", "int4kv", "fp8kv"
   The data type the inference backend uses to store the KV cache. Valid values: "None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"

   - ``fp8kv_sph``: FP8 static per-head quantization, served by the fa3 backend
   - ``fp8kv_spt``: FP8 static per-tensor quantization, served by the flashinfer backend

.. option:: --disable_cudagraph

102 changes: 102 additions & 0 deletions docs/CN/source/tutorial/fp8_kv_quantization.rst
@@ -0,0 +1,102 @@
.. _tutorial/fp8_kv_quantization_cn:

FP8 KV Quantization and Calibration Guide
==========================================

This chapter describes how to use FP8 KV inference in LightLLM, including:

- Running inference with a calibration file (``fp8kv_sph`` or ``fp8kv_spt``)
- The FP8 static per-head and per-tensor quantization modes
- Common errors and troubleshooting tips

Overview
--------

FP8 KV inference in LightLLM requires a prepared calibration file (``kv_cache_calib.json``),
loaded via ``--kv_quant_calibration_config_path``.
You can use one of the calibration files shipped under ``test/advanced_config/``,
export one with the `LightCompress <https://github.com/ModelTC/LightCompress>`_ tool, or supply your own compatible file.

Quantization Modes and Backend Mapping
--------------------------------------

LightLLM supports two FP8 KV quantization modes:

- ``fp8kv_sph``: FP8 Static Per-Head quantization, with an independent scale for each head, served by the ``fa3`` backend
- ``fp8kv_spt``: FP8 Static Per-Tensor quantization, with one scalar scale each for K and V, served by the ``flashinfer`` backend

Calibration files are tightly coupled to the quantization mode:

- ``fp8kv_sph`` requires a ``per_head`` calibration file
- ``fp8kv_spt`` requires a ``per_tensor`` calibration file

Do not mix calibration files across modes.
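The difference between the two scale granularities can be sketched in plain Python. This is an illustrative sketch only: the constant ``FP8_E4M3_MAX`` and the function names are assumptions for this example, not LightLLM APIs.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3

def per_head_scales(kv):
    """Static per-head (fp8kv_sph): one scale per head.

    kv is a nested list with shape (tokens, heads, head_dim)."""
    num_heads = len(kv[0])
    scales = []
    for h in range(num_heads):
        absmax = max(abs(x) for token in kv for x in token[h])
        scales.append(max(absmax, 1e-8) / FP8_E4M3_MAX)
    return scales

def per_tensor_scale(kv):
    """Static per-tensor (fp8kv_spt): a single scalar scale for the tensor."""
    absmax = max(abs(x) for token in kv for head in token for x in head)
    return max(absmax, 1e-8) / FP8_E4M3_MAX

# Two tokens, two heads, head_dim 2: head 0 peaks at 4.0, head 1 at 8.0.
kv = [[[1.0, -4.0], [2.0, 8.0]],
      [[0.5, 3.0], [-8.0, 1.0]]]
print(per_head_scales(kv))   # one scale per head
print(per_tensor_scale(kv))  # one scale for everything
```

A per-head scale tracks the dynamic range of each head separately, which is why the two modes need differently shaped calibration files.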

Launching FP8 Inference with a Calibration File
-----------------------------------------------

Example launch commands:

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv_sph \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv_spt \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- The ``fp8kv_sph`` and ``fp8kv_spt`` modes require ``--kv_quant_calibration_config_path``.
- The attention backend is selected automatically from the quantization mode and does not need to be specified manually.

.. note::

   When using ``fp8kv_spt`` mode (FP8 static per-tensor quantization with the flashinfer backend),
   ``flashinfer-python==0.6.5`` must be installed. The version installed by default is 0.6.3,
   which can cause runtime errors. Install the correct version with:

.. code-block:: console

$ pip install flashinfer-python==0.6.5

Calibration File Schema
-----------------------

Key fields in ``kv_cache_calib.json``:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: the actual scale values
- ``qmin`` / ``qmax``: FP8 numeric range parameters

When the calibration file is loaded, the model architecture, layer count, head count, and quantization type are validated against it.
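A minimal sketch of how such a file could be validated at load time. The field names follow the schema above, but ``load_kv_calibration`` is a hypothetical helper written for this guide, not the actual LightLLM loader:

```python
import json

def load_kv_calibration(path, llm_kv_type, num_layers, num_head):
    """Load kv_cache_calib.json and check it against the model and KV mode."""
    expected = {"fp8kv_sph": "per_head", "fp8kv_spt": "per_tensor"}[llm_kv_type]
    with open(path) as f:
        calib = json.load(f)
    # Mode/file mismatch is the usual cause of "quant_type not match" errors.
    if calib["quant_type"] != expected:
        raise ValueError(
            f"quant_type not match: file is {calib['quant_type']!r}, "
            f"mode {llm_kv_type} needs {expected!r}"
        )
    # Architecture mismatch: a file calibrated for another model is rejected.
    if calib["num_layers"] != num_layers or calib["num_head"] != num_head:
        raise ValueError("calibration file does not match the model architecture")
    return calib
```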

Multi-GPU Note
--------------

In multi-GPU (TP) deployments, the scales for the heads owned by the current rank are sliced out of the full table automatically, based on the rank index.
You still only need to provide a single full ``kv_cache_calib.json``.
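The per-rank slicing can be pictured as follows. This is an illustrative sketch (``slice_scales_for_rank`` is a made-up name), and the real partitioning logic inside LightLLM may differ:

```python
def slice_scales_for_rank(all_scales, tp_rank, tp_world_size):
    """all_scales: per-layer lists of per-head scales, shape (num_layers, num_head).

    Each TP rank owns a contiguous block of heads and only needs those scales."""
    num_head = len(all_scales[0])
    assert num_head % tp_world_size == 0, "head count must divide evenly across ranks"
    heads_per_rank = num_head // tp_world_size
    start = tp_rank * heads_per_rank
    return [layer[start:start + heads_per_rank] for layer in all_scales]

full = [[0.1, 0.2, 0.3, 0.4],   # layer 0
        [0.5, 0.6, 0.7, 0.8]]   # layer 1
print(slice_scales_for_rank(full, tp_rank=1, tp_world_size=2))
```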

Common Issues
-------------

1. Startup error saying ``--kv_quant_calibration_config_path`` is required

   You passed ``--llm_kv_type fp8kv_sph`` or ``fp8kv_spt`` without supplying a calibration file path.

2. ``quant_type not match`` error

   The quantization mode does not match the calibration file type, e.g. running ``fp8kv_sph`` with a ``per_tensor`` file.

3. Abnormal output quality after switching quantization modes

   Use a calibration file that matches the target quantization mode; do not reuse an incompatible file across modes.
1 change: 1 addition & 0 deletions docs/EN/source/index.rst
@@ -48,6 +48,7 @@ Documentation List
:caption: Deployment Tutorials

DeepSeek R1 Deployment <tutorial/deepseek_deployment>
FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
Multimodal Deployment <tutorial/multimodal>
Reward Model Deployment <tutorial/reward_model>
10 changes: 10 additions & 0 deletions docs/EN/source/tutorial/api_server_args.rst
@@ -333,6 +333,16 @@ Performance Optimization Parameters
* ``flashinfer``: Use FlashInfer backend
* ``triton``: Use Triton backend

.. option:: --llm_kv_type

Set the KV cache data type for inference. Available options:

* ``None``: Use the dtype from model's config.json
* ``int8kv``: INT8 KV quantization
* ``int4kv``: INT4 KV quantization
   * ``fp8kv_sph``: FP8 static per-head quantization, served by the fa3 backend
   * ``fp8kv_spt``: FP8 static per-tensor quantization, served by the flashinfer backend

.. option:: --disable_cudagraph

Disable cudagraph in the decoding phase
102 changes: 102 additions & 0 deletions docs/EN/source/tutorial/fp8_kv_quantization.rst
@@ -0,0 +1,102 @@
.. _tutorial/fp8_kv_quantization_en:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter describes FP8 KV inference in LightLLM, including:

- Running inference with calibration data (``fp8kv_sph`` or ``fp8kv_spt``)
- FP8 static per-head and per-tensor quantization modes
- Common errors and troubleshooting

Overview
--------

LightLLM FP8 KV inference requires a prepared calibration file (``kv_cache_calib.json``),
which is loaded by ``--kv_quant_calibration_config_path``.
You can use calibration files provided in ``test/advanced_config/``,
export one with `LightCompress <https://github.com/ModelTC/LightCompress>`_, or use your own compatible file.

Quantization Modes and Backend Mapping
------------------------------------------

LightLLM supports two FP8 KV quantization modes:

- ``fp8kv_sph``: FP8 Static Per-Head quantization, independent scale per head, uses ``fa3`` backend
- ``fp8kv_spt``: FP8 Static Per-Tensor quantization, one scalar for K and one scalar for V, uses ``flashinfer`` backend

Calibration files are mode-dependent:

- ``fp8kv_sph`` corresponds to ``per_head`` calibration files
- ``fp8kv_spt`` corresponds to ``per_tensor`` calibration files

Avoid mixing calibration files across different modes.
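In plain Python, the two granularities differ only in how the absolute maximum is reduced before dividing by the FP8 range. A minimal sketch (the function names and the ``448.0`` e4m3 limit are assumptions for illustration, not LightLLM code):

```python
E4M3_MAX = 448.0  # largest finite magnitude in float8_e4m3

def head_scales(kv):
    """fp8kv_sph: reduce abs-max over tokens and head_dim, one scale per head."""
    return [
        max(max(abs(x) for token in kv for x in token[h]), 1e-8) / E4M3_MAX
        for h in range(len(kv[0]))
    ]

def tensor_scale(kv):
    """fp8kv_spt: reduce abs-max over the entire tensor, one scalar scale."""
    return max(max(abs(x) for token in kv for head in token for x in head),
               1e-8) / E4M3_MAX

kv = [[[1.0, -4.0], [2.0, 8.0]],  # shape (tokens, heads, head_dim)
      [[0.5, 3.0], [-8.0, 1.0]]]
print(head_scales(kv))   # per-head scales: 4.0/448 for head 0, 8.0/448 for head 1
print(tensor_scale(kv))  # one global scale: 8.0/448
```

Because the per-head variant stores one scale per head per layer while the per-tensor variant stores a single scalar per K/V tensor, the two calibration file layouts are not interchangeable.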

Start FP8 Inference with Calibration
------------------------------------

Example launch commands for both modes:

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv_sph \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

.. code-block:: console

$ python -m lightllm.server.api_server \
--model_dir /path/to/model \
--llm_kv_type fp8kv_spt \
--kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- The ``fp8kv_sph`` and ``fp8kv_spt`` modes require ``--kv_quant_calibration_config_path``.
- The attention backend is selected automatically based on the quantization mode and does not need to be specified manually.

.. note::

When using ``fp8kv_spt`` mode (FP8 static per-tensor quantization with flashinfer backend),
you must install ``flashinfer-python==0.6.5``. The default installed version is 0.6.3,
which may cause runtime issues. Install the correct version with:

.. code-block:: console

$ pip install flashinfer-python==0.6.5

Calibration File Schema
-----------------------

Key fields in ``kv_cache_calib.json``:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: actual scale values
- ``qmin`` / ``qmax``: FP8 numeric range parameters

At load time, LightLLM validates architecture, layer count, head count, and quantization type.
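The checks can be pictured with a hypothetical validation helper (``check_calibration`` and its signature are invented for this sketch; the real loader lives inside LightLLM):

```python
import json

MODE_TO_QUANT_TYPE = {"fp8kv_sph": "per_head", "fp8kv_spt": "per_tensor"}

def check_calibration(path, llm_kv_type, num_layers, num_head):
    """Read kv_cache_calib.json and reject mode or architecture mismatches."""
    with open(path) as f:
        calib = json.load(f)
    want = MODE_TO_QUANT_TYPE[llm_kv_type]
    if calib["quant_type"] != want:
        raise ValueError(
            f"quant_type not match: got {calib['quant_type']!r}, expected {want!r}"
        )
    if (calib["num_layers"], calib["num_head"]) != (num_layers, num_head):
        raise ValueError("calibration file does not match the model architecture")
    return calib
```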

Multi-GPU Note
--------------

In multi-GPU (TP) setups, LightLLM slices the global scales to local rank heads automatically.
You only need to provide one full ``kv_cache_calib.json`` file.
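Conceptually, with ``tp_world_size`` ranks each rank keeps a contiguous block of ``num_head / tp_world_size`` heads and the matching columns of the scale table. An illustrative sketch, not the actual LightLLM partitioning code:

```python
def rank_local_scales(global_scales, rank, world_size):
    """global_scales: (num_layers, num_head) nested lists of per-head scales."""
    num_head = len(global_scales[0])
    assert num_head % world_size == 0, "heads must divide evenly across TP ranks"
    n = num_head // world_size
    # Each rank keeps a contiguous slice of heads from every layer.
    return [layer[rank * n:(rank + 1) * n] for layer in global_scales]

scales = [[0.10, 0.12, 0.14, 0.16]]  # one layer, four heads, TP=2
print(rank_local_scales(scales, rank=0, world_size=2))
print(rank_local_scales(scales, rank=1, world_size=2))
```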

Common Issues
-------------

1. Error says ``--kv_quant_calibration_config_path`` is required

You are using ``--llm_kv_type fp8kv_sph`` or ``fp8kv_spt`` without a calibration file path.

2. ``quant_type not match`` error

Usually caused by quantization mode/file mismatch (for example, using a ``per_tensor`` file with ``fp8kv_sph``).

3. Abnormal quality after mode switch

Use a calibration file that matches the target quantization mode instead of reusing an incompatible file.
6 changes: 6 additions & 0 deletions lightllm/common/basemodel/attention/create_utils.py
@@ -35,6 +35,12 @@
# "fa3": Fp8Fa3AttBackend,
# "flashinfer": Fp8FlashInferAttBackend,
},
"fp8kv_sph": {
"fa3": Fp8Fa3AttBackend,
},
"fp8kv_spt": {
"flashinfer": Fp8FlashInferAttBackend,
},
}

mla_data_type_to_backend = {
58 changes: 16 additions & 42 deletions lightllm/common/basemodel/attention/fa3/fp8.py
@@ -45,24 +45,9 @@ def init_state(self):
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
)
        # To reduce inference-time compute, k_descale and v_descale are initialized outside the inference loop
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)


def prefill_att(
self,
@@ -89,19 +74,21 @@ def _fp8_prefill_att(
) -> torch.Tensor:
self.backend: Fp8Fa3AttBackend = self.backend # for typing

q_head_num = q.shape[1]
q_head_dim = q.shape[2]
k_head_num = k.shape[1]
q, q_scale = q_per_head_fp8_quant(
q,
q.reshape(q.shape[0], k_head_num, -1),
self.infer_state.b_seq_len,
self.cu_seqlens_q,
self.mid_token_batch_ids,
token_batch_ids=self.mid_token_batch_ids,
)
k_head_num = k.shape[1]
k_head_dim = k.shape[2]
cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)
o = flash_attn_with_kvcache(
q=q,
q=q.reshape(-1, q_head_num, q_head_dim),
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
@@ -141,24 +128,9 @@ def init_state(self):
head_num = mem_manager.head_num

        # To reduce inference-time compute, k_descale and v_descale are initialized outside the inference loop
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
if offline_scales is not None
else torch.ones(
(mem_manager.layer_num, batch_size, head_num),
dtype=torch.float32,
device=device,
)
)
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)

return

def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"):
@@ -200,9 +172,11 @@ def _fp8_decode_att(
layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)

q_head_num = q.shape[1]
q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
if scaled_fp8_quant is None:
raise ImportError("scaled_fp8_quant is unavailable. Please install vllm to enable FP8 decode attention.")
q, q_scale = scaled_fp8_quant(q.reshape(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
o = flash_attn_with_kvcache(
q=q.view(-1, q_head_num, k_head_dim),
q=q.reshape(-1, q_head_num, k_head_dim),
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,