diff --git a/docs/_static/img/pymllm-arch.png b/docs/_static/img/pymllm-arch.png new file mode 100644 index 000000000..37c48b2a0 Binary files /dev/null and b/docs/_static/img/pymllm-arch.png differ diff --git a/docs/index.rst b/docs/index.rst index 3db7d58e2..22b60ca96 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -357,6 +357,11 @@ Documents cache/index +.. toctree:: + :maxdepth: 2 + + pymllm_runtime/index + .. toctree:: :maxdepth: 2 diff --git a/docs/pymllm_runtime/developer_guide.rst b/docs/pymllm_runtime/developer_guide.rst new file mode 100644 index 000000000..47b528659 --- /dev/null +++ b/docs/pymllm_runtime/developer_guide.rst @@ -0,0 +1,220 @@ +pymllm Developer Guide +====================== + +总览 +---------------------------------------- + +本文档面向希望为 ``pymllm`` 增加模型、量化格式、kernel 或性能优化的开发者。当前代码处在 +快速演进阶段,推荐遵循“小步验证、边界清晰、先单测后服务级验证”的工作方式。 + +开发环境建议 +---------------------------------------- + +推荐使用 editable install,便于修改 Python 代码后直接验证: + +.. code-block:: bash + + cd + SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e . + python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation + +最小检查: + +.. code-block:: bash + + python3 - <<'PY' + import pymllm + import mllm_kernel + print("ok") + PY + +``mllm-kernel`` 的 JIT 编译产物会写入 ``~/.cache/mllm_kernel``。正常修改后重新运行 +会触发相应 kernel 的加载或编译;只有在验证首次编译行为、排查失败缓存、或更换 CUTLASS +等外部头文件来源时,才需要清理对应缓存: + +.. code-block:: bash + + rm -rf ~/.cache/mllm_kernel/ + +新增模型 +---------------------------------------- + +新增模型时,优先复用现有 ``pymllm.layers`` 和 ``pymllm.executor`` 约定,而不是把 +HuggingFace 模型直接包进服务。 + +推荐步骤: + +1. 新增 ``pymllm/models/.py``。 +2. 在 ``pymllm/models/__init__.py`` 注册 architecture 字符串。 +3. 实现模型类,保持 ``forward(input_ids, positions, forward_batch)`` 风格。 +4. 所有 linear layer 都接受 ``quant_method``。 +5. 实现 ``load_weights``,处理 checkpoint key、stacked projection 和 tied embedding。 +6. 增加最小单测。 +7. 最后做服务级 smoke test。 + +最小测试建议: + +.. code-block:: bash + + pytest pymllm/tests/test__model_registry.py -q + pytest pymllm/tests/test__weight_loading.py -q + pytest pymllm/tests/test__forward_timing.py -q + +新增量化 scheme +---------------------------------------- + +新增量化路径时,不建议在模型文件里写格式判断。推荐保持以下分层: + +.. code-block:: text + + QuantizationConfig + parses checkpoint config + decides whether a layer is quantized + + LinearMethod + owns linear layer lifecycle + + Scheme + owns checkpoint-facing params + owns post-load layout conversion + owns kernel apply path + +``create_weights`` 应注册 checkpoint-facing 参数名。``process_weights_after_loading`` 应作为 +checkpoint layout 到 runtime kernel layout 的唯一转换边界。``apply`` 中只做 forward 必需的 +runtime 计算,不应重复做权重 repack。 + +新增量化路径至少需要覆盖: + +- config 解析测试。 +- ``ignore`` / prefix 匹配测试。 +- 参数注册 shape/dtype 测试。 +- post-load layout 转换测试。 +- forward correctness 或 smoke test。 + +新增 CUDA JIT kernel +---------------------------------------- + +若 kernel 适合走 ``mllm-kernel`` 的 TVM-FFI JIT 路径,推荐结构如下: + +.. 
code-block:: text + + mllm-kernel/mllm_kernel/cuda/csrc//.cuh + mllm-kernel/mllm_kernel/cuda/jit/.py + mllm-kernel/tests/test_.py + mllm-kernel/benchmarks/bench_.py + +Python wrapper 应负责: + +- 校验输入 shape、dtype、device。 +- 分配输出 tensor。 +- 调用 ``@jit`` 包装后的 compiled module。 +- 暴露稳定、简洁的 Python API。 + +CUDA/C++ source 应尽量只表达 kernel 语义,不混入 checkpoint 配置解析或模型层逻辑。 + +如果 kernel 依赖 CUTLASS 等重模板库,可以先做编译 spike。确认 Jetson 目标设备上的编译时间、 +缓存路径、include 来源和内存占用后,再决定使用 TVM-FFI JIT、torch extension JIT 或 AOT 构建。 + +服务级验证 +---------------------------------------- + +服务级 smoke test 应覆盖: + +- ``/v1/models`` 可返回。 +- 文本 ``/v1/chat/completions`` 可完成。 +- 图文模型能处理容器内图片绝对路径。 +- streaming 与 non-streaming 至少各测一次。 +- 中止请求或客户端断连不会泄漏 running request。 + +示例: + +.. code-block:: bash + + curl -s --noproxy '*' http://127.0.0.1:30000/v1/models ; echo + + curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "只回复 ok"}], + "max_tokens": 8, + "temperature": 0.0, + "stream": false + }' ; echo + +性能验证 +---------------------------------------- + +性能数据需要固定口径,否则不同记录之间很难比较。建议记录: + +- commit hash。 +- JetPack / L4T 版本。 +- GPU 型号和 compute capability。 +- PyTorch、Triton、FlashInfer、CUDA 版本。 +- 模型路径和量化格式。 +- 启动命令。 +- prompt token 数、max tokens、temperature。 +- 是否启用 radix cache、CUDA Graph、shared queue。 +- 是否包含首次 JIT 编译。 + +对服务级请求,建议丢弃第一次 warmup 结果,记录第 2/3 次请求的 prefill/decode 统计。 +对 kernel microbench,建议单独记录 warmup、重复次数、输入 shape 和 dtype。 + +常见问题定位 +---------------------------------------- + +启动失败 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +优先确认: + +- ``pymllm`` 和 ``mllm_kernel`` 是否来自预期源码目录或安装版本。 +- ``model_path`` 和 ``tokenizer_path`` 是否在容器内可见。 +- ``transformers`` 是否能读取目标 ``config.json``。 +- CUDA 是否可用,``torch.cuda.get_device_capability()`` 是否符合量化 kernel 要求。 + +W8A8 编译失败 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +优先确认: + +- ``CUTLASS_HOME`` 是否设置正确。 +- ``flashinfer`` 是否包含 bundled CUTLASS。 +- ``~/.cache/mllm_kernel/cutlass_int8_scaled_mm/`` 是否存在旧的失败缓存。 +- 当前 GPU 是否为 SM80-SM89。 + +请求卡住或 CPU 占用高 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +优先确认: + +- scheduler 是否启用了 idle sleep。 +- tokenizer / scheduler / detokenizer 子进程是否全部存活。 +- 是否有请求已经断连但未 abort。 +- ``max_total_tokens`` 是否过小导致 KV allocation 反复失败和 eviction。 + +输出异常 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +优先确认: + +- tokenizer chat template 是否符合目标模型。 +- EOS token 是否从 config、generation_config 或 tokenizer 中正确解析。 +- 量化模型的 ``ignore`` 是否覆盖视觉分支、embedding、norm 和 lm_head 等不应量化模块。 +- ``process_weights_after_loading`` 是否已执行。 + +贡献建议 +---------------------------------------- + +开发时尽量保持以下边界: + +- 服务协议变化放在 ``pymllm/server``。 +- 请求/响应结构放在 ``pymllm/engine/io_struct.py``。 +- 调度策略放在 ``pymllm/orchestrator/scheduler_process.py``。 +- GPU 资源和 forward 逻辑放在 ``pymllm/executor``。 +- 模型结构放在 ``pymllm/models``。 +- 基础层放在 ``pymllm/layers``。 +- 量化格式放在 ``pymllm/quantization``。 +- 自定义 kernel 放在 ``mllm-kernel``。 + +这样可以避免把一次模型适配写成跨层补丁,也方便后续把同一能力复用到更多模型和设备。 diff --git a/docs/pymllm_runtime/index.rst b/docs/pymllm_runtime/index.rst new file mode 100644 index 000000000..b6c9bdf2e --- /dev/null +++ b/docs/pymllm_runtime/index.rst @@ -0,0 +1,12 @@ +pymllm Runtime +============== + +.. 
toctree:: + :maxdepth: 2 + + setup_and_usage + runtime_design + models_and_quantization + kernels_and_acceleration + developer_guide + diff --git a/docs/pymllm_runtime/kernels_and_acceleration.rst b/docs/pymllm_runtime/kernels_and_acceleration.rst new file mode 100644 index 000000000..d5d09c30c --- /dev/null +++ b/docs/pymllm_runtime/kernels_and_acceleration.rst @@ -0,0 +1,203 @@ +pymllm Kernels and Acceleration +=============================== + +总览 +---------------------------------------- + +``pymllm`` 的性能路径由多类加速组件共同组成: + +- FlashInfer:paged KV cache attention。 +- CUDA Graph:decode 阶段减少 CPU launch overhead。 +- Triton:W8A8 per-token activation quantization。 +- CUTLASS:W8A8 INT8 Tensor Core GEMM。 +- ``mllm-kernel``:基于 TVM-FFI / torch extension 的 JIT kernel 工具包。 + +这些组件不是彼此替代关系,而是在不同层次承担职责。attention backend 解决 KV cache +attention;CUDA Graph 解决重复 decode step 的 launch overhead;Triton 和 CUTLASS 解决量化 +linear 的核心计算;``mllm-kernel`` 为项目内自定义 CUDA/C++ kernel 提供封装、缓存和工具。 + +mllm-kernel +---------------------------------------- + +``mllm-kernel`` 是 mllm 项目中的高性能 kernel 包。当前 Python 侧主要包含: + +- ``mllm_kernel.cuda.jit``:CUDA JIT kernel wrapper。 +- ``mllm_kernel.cpu.jit``:CPU JIT kernel wrapper。 +- ``mllm_kernel.jit_utils``:JIT 编译、缓存、注册表和工具函数。 + +CUDA JIT kernel 的典型结构是: + +.. code-block:: text + + Python wrapper + -> @jit(...) + -> include CUDA/C++ source + -> export TVM-FFI typed function + -> compile on first use + -> reuse cached shared library + +默认 JIT 缓存目录为: + +.. code-block:: text + + ~/.cache/mllm_kernel/ + +``mllm-kernel`` 的 JIT 路径与 SGLang 的 ``jit_kernel`` 设计关系更直接:二者都强调轻量 +JIT、运行时选择模板实例、避免大型 AOT torch extension 带来的长编译周期。与此同时,SGLang +的 ``sgl-kernel`` AOT kernel 仍然是重要参考,尤其适合对照量化 GEMM 的语义和性能。 + +TVM-FFI JIT 路径 +---------------------------------------- + +``mllm_kernel.jit_utils.jit`` decorator 会将 Python 函数包装成一个按需编译的 kernel 调用。 +它负责: + +- 根据 tensor device 推断 CPU/CUDA 目标。 +- 将 Python 参数转换为 C++ template 参数。 +- 拼接 C++/CUDA source 和 export wrapper。 +- 调用 TVM-FFI 编译并加载 shared library。 +- 将编译结果缓存到 ``~/.cache/mllm_kernel``。 + +这种方式适合小而明确的自定义 kernel,例如: + +- ``create_kv_indices``:构造 FlashInfer KV index metadata。 +- ``store_cache``:将 K/V 写入 KVPool。 +- ``gptq_marlin_repack``:Marlin weight layout 转换。 +- ``gptq_marlin_gemm``:W4A16 Marlin GEMM。 + +W8A8 CUTLASS kernel 当前使用 ``torch.utils.cpp_extension.load`` 编译。这是因为 CUTLASS +模板和 include 体系较重,当前以稳定通过 Jetson SM87 编译为优先。 + +FlashInfer Attention +---------------------------------------- + +``pymllm.layers.attention.flashinfer_backend.FlashInferAttnBackend`` 封装 FlashInfer 的 paged +KV cache attention。它负责: + +- 为 prefill 和 decode 准备 ``kv_indptr``、``kv_indices``、``kv_last_page_len`` 等 metadata。 +- 管理全局 workspace buffer。 +- 根据是否存在 sliding window 选择 wrapper dispatch。 +- 在 decode 中根据 GQA group size 和 KV dtype 决定是否使用 tensor core 路径。 +- 为 CUDA Graph capture / replay 提供专用 metadata 初始化接口。 + +prefill 和 decode 使用不同 wrapper: + +.. code-block:: text + + prefill / extend + BatchPrefillWithPagedKVCacheWrapper + BatchPrefillWithRaggedKVCacheWrapper + + decode + BatchDecodeWithPagedKVCacheWrapper + +attention backend 只负责 attention 计算和 metadata,不负责请求调度和 KV slot 生命周期。KV slot +的分配、释放和 prefix cache 命中由 scheduler / model runner 侧完成。 + +CUDA Graph +---------------------------------------- + +``pymllm.executor.cuda_graph_runner.CudaGraphRunner`` 用于 decode step 的 CUDA Graph capture +和 replay。它的目标是减少小 batch decode 中 CPU launch overhead。 + +初始化阶段会按一组离散 batch size 捕获 graph: + +.. code-block:: text + + [1, 2, 4, 8, 12, 16, 24, 32, ...] 
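+
+捕获的 batch size 是离散的,replay 时真实 batch 会被向上取整到最近的 captured batch size。
+下面是一个示意性的选择逻辑;``CAPTURE_BS`` 与 ``select_graph_bs`` 是说明用的假设名称,
+实际捕获列表与回退策略以 ``CudaGraphRunner`` 的实现为准:
+
+.. code-block:: python
+
+    import bisect
+
+    # 假设:与上面列表一致的已捕获 batch size(按升序排列)。
+    CAPTURE_BS = [1, 2, 4, 8, 12, 16, 24, 32]
+
+    def select_graph_bs(real_bs: int):
+        """把真实 decode batch size 向上取整到最近的 captured size。
+
+        超出捕获范围时返回 None,表示应回退到普通 eager decode 路径。
+        """
+        idx = bisect.bisect_left(CAPTURE_BS, real_bs)
+        return CAPTURE_BS[idx] if idx < len(CAPTURE_BS) else None
+
+    # 例如 real_bs = 5 会被 padding 到 8,多出的 slot 用占位输入填充。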
+ +每个 captured graph 复用预分配输入 buffer: + +- ``input_ids`` +- ``req_pool_indices`` +- ``seq_lens`` +- ``out_cache_loc`` +- ``positions`` +- ``mrope_position_deltas`` + +replay 时,真实 batch 会被 padding 到最近的 captured batch size。attention backend 会走专用 +``init_forward_metadata_replay_cuda_graph`` 路径,避免使用普通动态 metadata 初始化。 + +CUDA Graph 只覆盖 decode 主路径。调试模型、调试 attention metadata 或定位 shape 问题时,可以 +使用 ``--server.disable_cuda_graph`` 暂时关闭。 + +W4A16 Marlin +---------------------------------------- + +W4A16 路径复用 Marlin kernel。checkpoint 权重先以 ``weight_packed`` 和 ``weight_scale`` +加载,然后在 post-load 阶段转换为 Marlin runtime layout。 + +关键 kernel: + +- ``mllm_kernel.cuda.jit.gptq_marlin_repack`` +- ``mllm_kernel.cuda.jit.gptq_marlin`` + +执行约束包括: + +- SM80+ +- output partition 可被 64 整除 +- input partition 可被 128 整除 +- group size 当前主路径为 32 + +这种路径适合 AWQ / W4A16 类权重量化模型,activation 保持 FP16/BF16。 + +W8A8 Triton + CUTLASS +---------------------------------------- + +W8A8 路径包含两个核心 kernel: + +1. ``pymllm.quantization.kernels.int8_activation_triton.per_token_quant_int8`` +2. ``mllm_kernel.cuda.jit.int8_scaled_mm_cutlass.int8_scaled_mm`` + +运行时链路: + +.. code-block:: text + + [M, K] fp16/bf16 activation + -> Triton per-token absmax + round + int8 cast + -> [M, K] int8 + [M, 1] fp32 scale + -> CUTLASS int8 GEMM with per-row/per-col scales + -> [M, N] fp16/bf16 output + +CUTLASS kernel 要求 ``mat_b`` 为 ``[K, N]`` column-major,因此 W8A8 scheme 会在 +``process_weights_after_loading`` 中把 checkpoint 的 ``[N, K]`` INT8 weight 转成对应布局。 + +当前 CUTLASS include 查找顺序为: + +1. ``CUTLASS_HOME/include`` +2. ``flashinfer`` bundled CUTLASS +3. 系统 include 目录 + +如果找不到 CUTLASS 头文件,W8A8 初始化会失败。生产环境建议在镜像中固定 CUTLASS 来源,避免 +不同节点使用不同版本头文件。 + +GDN decode kernel +---------------------------------------- + +Qwen3.5 等 hybrid 模型可能包含 GDN / linear attention 层。``pymllm`` 为这类模型保留了: + +- ``pymllm.layers.attention.gdn_backend`` +- ``pymllm.layers.attention.hybrid_backend`` +- ``mllm_kernel.cuda.jit.gdn_decode`` +- ``MambaRadixCache`` / GDN state cache 相关结构 + +当前文档重点覆盖 Qwen3 / Qwen3-VL 主路径。GDN 相关路径仍应以具体模型和测试结果为准。 + +调试与观测 +---------------------------------------- + +常用检查命令: + +.. code-block:: bash + + python3 -m mllm_kernel show-env + python3 -m mllm_kernel show-config + python3 -m pymllm show-config + +当首次运行时间异常长时,应区分: + +- 模型权重加载时间。 +- FlashInfer / CUDA context 初始化时间。 +- CUTLASS JIT 编译时间。 +- CUDA Graph capture 时间。 +- 实际 prefill/decode 时间。 diff --git a/docs/pymllm_runtime/models_and_quantization.rst b/docs/pymllm_runtime/models_and_quantization.rst new file mode 100644 index 000000000..e7d92dd1e --- /dev/null +++ b/docs/pymllm_runtime/models_and_quantization.rst @@ -0,0 +1,226 @@ +pymllm Models and Quantization +============================== + +总览 +---------------------------------------- + +``pymllm`` 的模型实现遵循 PyTorch ``nn.Module`` 风格,并通过 HuggingFace +``config.architectures`` 字段选择模型类。当前重点支持 Qwen3 family: + +- ``Qwen3ForCausalLM``:文本模型,例如 Qwen3-0.6B。 +- ``Qwen3VLForConditionalGeneration``:图文模型,例如 Qwen3-VL-2B-Instruct。 +- ``Qwen3_5ForCausalLM`` 和 ``Qwen3_5ForConditionalGeneration``:hybrid attention / GDN + 相关模型骨架。 + +量化系统以 linear layer 为核心,使用插件式 ``LinearMethodBase`` 生命周期: + +.. code-block:: text + + QuantizationConfig + -> get_quant_method(layer, prefix) + -> LinearMethodBase + -> create_weights() + -> process_weights_after_loading() + -> apply() + +模型注册 +---------------------------------------- + +模型注册表位于 ``pymllm/models/__init__.py``。运行时会根据 HuggingFace config 中的 +architecture 字符串懒加载模型类: + +.. 
code-block:: text + + "Qwen3ForCausalLM" + -> pymllm.models.qwen3.Qwen3ForCausalLM + + "Qwen3VLForConditionalGeneration" + -> pymllm.models.qwen3_vl.Qwen3VLForConditionalGeneration + + "Qwen3_5ForCausalLM" + -> pymllm.models.qwen3_5.Qwen3_5ForCausalLM + +这种注册方式让服务启动阶段只导入目标模型所需的代码,避免在命令行工具或轻量检查中提前加载 +大量 PyTorch/CUDA 依赖。 + +Qwen3 文本模型 +---------------------------------------- + +``Qwen3ForCausalLM`` 使用标准 decoder-only 结构: + +- token embedding +- 多层 decoder block +- Q/K Norm +- 1D RoPE +- MLP +- final norm +- lm head + +它复用 ``RadixAttention``、``RMSNorm``、``MLP``、``ColumnParallelLinear`` 和 +``RowParallelLinear`` 等基础层。与 Qwen3-VL 文本分支相比,Qwen3 文本模型使用 1D RoPE, +不需要多模态 M-RoPE 的三维 position 逻辑。 + +Qwen3-VL 图文模型 +---------------------------------------- + +``Qwen3VLForConditionalGeneration`` 在文本 decoder 外增加视觉输入处理和 M-RoPE 位置编码。 +在一次图文请求中: + +1. tokenizer / processor 处理 messages 和图片路径。 +2. ``TokenizerProcess`` 生成 token ids 和多模态输入 tensor。 +3. 多模态 tensor 通过 ZMQ 或 shared queue 送到 scheduler。 +4. 模型 forward 中先处理视觉侧输入,再进入语言模型 prefill/decode。 +5. decode 阶段使用每个请求保存的 ``mrope_position_delta`` 修正位置。 + +当前 W8A8 量化主要覆盖语言 decoder 的线性层;视觉 encoder、embedding、LayerNorm 和 +``lm_head`` 保持全精度。 + +量化配置解析 +---------------------------------------- + +服务启动时,``ModelRunner`` 会解析量化配置。优先级为: + +1. 命令行 ``--quantization.method``。 +2. checkpoint 目录中的量化配置文件。 +3. ``config.json`` 中的 ``quantization_config`` 字段。 + +``compressed-tensors`` 路径使用 ``pymllm.quantization.methods.compressed_tensors``, +当前支持两类签名: + +.. list-table:: + :header-rows: 1 + + * - 签名 + - 格式 + - 权重 + - 激活 + - 执行路径 + * - W4A16 + - ``pack-quantized`` + - 4-bit packed weight + - FP16/BF16 activation + - Marlin WNA16 GEMM + * - W8A8 + - ``int-quantized`` + - INT8 static weight + - INT8 dynamic per-token activation + - Triton quant + CUTLASS INT8 GEMM + +``ignore`` 字段会让匹配前缀的模块跳过量化。例如 Qwen3-VL 的视觉分支通常保留为全精度。 + +W4A16 / AWQ Marlin 路径 +---------------------------------------- + +W4A16 路径面向 ``compressed-tensors`` 的 ``pack-quantized`` checkpoint。当前支持的 +约束是: + +- ``format == "pack-quantized"`` +- ``weights.num_bits == 4`` +- ``weights.group_size == 32`` +- ``weights.symmetric == true`` +- ``actorder == null`` +- GPU capability 不低于 SM80 + +权重加载和执行分为三个阶段: + +.. code-block:: text + + checkpoint tensors + weight_packed / weight_scale / weight_shape + │ + ▼ + process_weights_after_loading() + gptq_marlin_repack() + marlin_permute_scales() + create runtime-only zero/g_idx placeholders + │ + ▼ + apply() + gptq_marlin_gemm() + +``create_weights`` 注册与 checkpoint 对齐的参数名,保证 safetensors 加载逻辑可以按名称写入。 +``process_weights_after_loading`` 是 checkpoint layout 到 runtime kernel layout 的边界,repack +不应放在通用权重加载器或每次 forward 中。 + +W8A8 INT8 路径 +---------------------------------------- + +W8A8 路径面向 ``compressed-tensors`` 的 ``int-quantized`` checkpoint。当前支持的约束是: + +- ``format == "int-quantized"`` +- ``weights.num_bits == 8`` +- ``weights.type == "int"`` +- ``weights.strategy == "channel"`` +- ``weights.dynamic == false`` +- ``weights.symmetric == true`` +- ``input_activations.num_bits == 8`` +- ``input_activations.type == "int"`` +- ``input_activations.strategy == "token"`` +- ``input_activations.dynamic == true`` +- ``input_activations.symmetric == true`` +- W8A8 CUTLASS 路径当前支持 Ampere / SM8x GPU(SM80-SM89)。已验证目标为 + Jetson Orin SM87;Hopper / SM90 暂不包含在当前支持范围内。 + +执行链路如下: + +.. 
code-block:: text + + x(fp16/bf16) + │ + ▼ + per_token_quant_int8() [Triton] + │ + ├── x_q(int8) + └── x_scale(float32) + │ + ▼ + int8_scaled_mm() [CUTLASS] + │ + └── output(fp16/bf16) + +checkpoint 中的 INT8 权重通常是 ``[N, K]`` row-major。``process_weights_after_loading`` +会将其转换为 ``[K, N]`` column-major 视图并整理 ``weight_scale``,以满足 CUTLASS kernel +接口约定。 + +LinearMethod 生命周期 +---------------------------------------- + +所有 linear layer 都持有一个 ``quant_method``: + +- 未量化时使用 ``UnquantizedLinearMethod``,注册普通 ``weight`` 并调用 ``F.linear``。 +- 量化时由 ``QuantizationConfig.get_quant_method(layer, prefix)`` 返回具体方法。 + +典型生命周期: + +1. 模型构造时,linear layer 调用 ``quant_method.create_weights`` 注册参数。 +2. ``model.load_weights`` 根据参数名和 ``weight_loader`` 写入 checkpoint tensor。 +3. 所有权重加载完成后,``ModelRunner`` 遍历模块并调用 + ``process_weights_after_loading``。 +4. forward 时,linear layer 委托 ``quant_method.apply`` 执行。 + +这个边界使新增量化方法时不需要改动模型主逻辑,只需要实现新的 config 和 scheme。 + +新增模型的建议流程 +---------------------------------------- + +新增模型时建议遵循以下顺序: + +1. 在 ``pymllm/models/`` 中新增模型文件。 +2. 在 ``pymllm/models/__init__.py`` 注册 HuggingFace architecture 字符串。 +3. 实现最小 forward 接口:``forward(input_ids, positions, forward_batch)``。 +4. 复用现有基础层,并确保 linear layer 接受 ``quant_method``。 +5. 实现 ``load_weights``,处理 checkpoint 前缀、stacked projection 和 tied embedding。 +6. 增加 registry、weight loading、forward timing 的单元测试。 +7. 最后再做服务级 smoke test。 + +新增量化方法的建议流程 +---------------------------------------- + +新增量化方法时建议保持三层结构: + +1. ``QuantizationConfig``:解析 checkpoint 配置,决定某个 layer 是否量化。 +2. ``LinearMethod``:承接 layer 生命周期。 +3. ``Scheme``:处理具体格式的参数注册、post-load 转换和 kernel apply。 + +不要把 checkpoint 格式判断写入模型类,也不要把 runtime repack 隐藏在通用 +``weight_loader`` 中。这样可以保证模型结构、权重格式和 kernel layout 三者的边界清晰。 diff --git a/docs/pymllm_runtime/runtime_design.rst b/docs/pymllm_runtime/runtime_design.rst new file mode 100644 index 000000000..309ea7a21 --- /dev/null +++ b/docs/pymllm_runtime/runtime_design.rst @@ -0,0 +1,204 @@ +pymllm Runtime Design +===================== + +总览 +---------------------------------------- + +``pymllm`` 是 mllm 的 Python serving runtime。它不是传统意义上的 mllm C++ +Backend,而是一套围绕 PyTorch/CUDA 生态构建的在线推理服务运行时。当前实现面向 +Jetson Orin 等边缘 GPU 设备,重点支持 Qwen3、Qwen3-VL 和 Qwen3.5 系列模型。 + +它的设计参考了 SGLang serving runtime 的核心分层,但进行了明显收缩:当前主路径以 +单机单 GPU 为目标,优先保证在 Jetson 上可运行、可调试、可扩展,而不是覆盖大规模 +分布式 serving 的全部复杂度。 + +.. figure:: ../_static/img/pymllm-arch.png + :width: 100% + :alt: pymllm runtime architecture + :align: center + + Figure 1: pymllm runtime architecture. + +整体分层 +---------------------------------------- + +从开发者视角看,``pymllm`` 可以分为五层: + +1. **服务入口层**:FastAPI HTTP server,提供 OpenAI-compatible API 和原生 + ``/generate`` API。 +2. **配置层**:``ServerConfig``、``ModelConfig``、``QuantizationConfig`` 统一解析 + 模型路径、dtype、调度参数、缓存参数、量化参数和加速开关。 +3. **控制面**:``Engine`` 启动 tokenizer、scheduler、detokenizer 子进程,并在主进程中 + 维护 request/response 状态。 +4. **数据面**:scheduler 持有 GPU-owning ``ModelRunnerProcess``,负责 batch 构造、 + KV cache 分配、prefix cache 命中、forward 和 sampling。 +5. **加速层**:FlashInfer、CUDA Graph、Triton、CUTLASS 和 ``mllm-kernel`` 提供 attention、 + quantization、GEMM 和缓存写入等高频算子。 + +进程拓扑 +---------------------------------------- + +``Engine`` 在启动时创建三个子进程,并在主进程中保留 request/response 管理逻辑: + +.. 
code-block:: text + + Main Process + ├── FastAPI Server + ├── Engine + └── RequestResponseProcess + │ + │ ZMQ + ▼ + TokenizerProcess + │ + │ ZMQ or shared queue + ▼ + SchedulerProcess + └── ModelRunnerProcess (in-process, owns GPU resources) + │ + │ ZMQ + ▼ + DetokenizerProcess + │ + │ ZMQ + ▼ + RequestResponseProcess + +这个拓扑的核心取舍是:GPU 资源由 scheduler 进程内的 ``ModelRunnerProcess`` 直接持有。 +这样 scheduler 可以在同一进程中完成调度、KV cache 资源释放、prefix cache 更新和模型 +forward,避免再引入 model worker 进程之间的 GPU 资源同步。 + +请求生命周期 +---------------------------------------- + +一次 chat completion 请求的典型路径如下: + +1. HTTP server 接收请求并转换为 ``GenerateReqInput``。 +2. ``RequestResponseProcess`` 为请求分配 request id,并把请求送入 tokenizer。 +3. ``TokenizerProcess`` 调用 tokenizer / processor,生成 ``TokenizedGenerateReqInput``。 +4. ``SchedulerProcess`` 接收 tokenized request,创建 ``Req``,放入等待队列。 +5. scheduler 根据 token budget、running request 数量和 prefill/decode 状态构造 + ``ScheduleBatch``。 +6. ``ModelRunnerProcess`` 为 batch 分配 request slot 和 KV slot,执行 prefix matching。 +7. ``ModelRunner`` 构造 ``ForwardBatch``,初始化 attention backend metadata,调用模型 + ``forward``,并对 logits 做 sampling。 +8. scheduler 更新每个 ``Req`` 的输出 token、finished reason 和 timing 字段。 +9. ``DetokenizerProcess`` 将 token id 转回文本。 +10. HTTP server 以普通 JSON 或 SSE streaming 形式返回结果。 + +控制面:Engine 与配置 +---------------------------------------- + +``pymllm.configs.server_config.ServerConfig`` 是服务运行时的主配置对象。它覆盖: + +- 模型和 tokenizer:``model_path``、``tokenizer_path``、``load_format``、``dtype``。 +- HTTP server:``host``、``port``、``api_key``、``served_model_name``。 +- 调度与内存:``max_running_requests``、``max_total_tokens``、``max_prefill_tokens``、 + ``mem_fraction_static``。 +- 加速后端:``attention_backend``、``gdn_decode_backend``、``disable_cuda_graph``、 + ``enable_torch_compile``。 +- IPC 与多模态传输:``enable_shared_queue``、``tensor_transport_mode``、 + ``cuda_ipc_pool_size_mb``。 +- 观测与调试:``log_level``、``decode_log_interval``。 + +``Engine`` 启动前会加载 HuggingFace config,解析 EOS token、默认输出长度和 dtype,并确保 +model/tokenizer 路径可用。启动后,``Engine`` 会监控子进程健康状态;任一核心子进程异常退出, +服务会被标记为 unhealthy。 + +调度器 +---------------------------------------- + +``SchedulerProcess`` 是 pymllm 的中心调度组件。它负责: + +- 接收 tokenized requests。 +- 将输入请求转换为内部 ``Req`` 状态。 +- 根据 prefill/decode 状态构造 ``ScheduleBatch``。 +- 控制 ``max_running_requests``、``max_total_tokens``、``max_prefill_tokens`` 等资源约束。 +- 在请求结束或中止时释放 request slot 和 KV slot。 +- 将 decode token 发送给 detokenizer。 + +当前调度策略以 FCFS 和单 GPU 资源约束为主。``max_prefill_tokens`` 用于限制一轮调度 +可接纳的 prefill token 数;长 prompt 的运行时 chunked prefill 切分仍待后续接入。 + +ModelRunner +---------------------------------------- + +``ModelRunner`` 是真正执行模型 forward 的组件。它在初始化阶段完成: + +1. 设置 CUDA device 和默认 dtype。 +2. 加载模型类和 safetensors 权重。 +3. 解析模型 metadata,例如 layer 数、head 数、head dim、context length。 +4. 初始化 request-to-token pool、token-to-KV pool 和 KV allocator。 +5. 初始化 attention backend。 +6. 预热 cuBLAS。 +7. 按配置捕获 decode CUDA Graph。 + +forward 阶段分为 extend 和 decode 两类: + +- **extend / prefill**:处理 prompt token,写入 KV cache,并返回每个请求最后一个 token 的 + logits。 +- **decode**:每个请求生成一个新 token,复用已有 KV cache 和 attention metadata。 + +KV cache 与 prefix cache +---------------------------------------- + +``pymllm.mem_cache.memory_pool`` 中的 KV 管理采用三层结构: + +.. 
code-block:: text + + ReqToTokenPool + maps (request slot, position) -> kv index + + TokenToKVPoolAllocator + manages free integer KV slots + + KVPool + stores per-layer K/V tensors on GPU + +``TokenToKVPoolAllocator`` 使用 free-list 管理 KV slot,并通过批量释放接口降低大量请求结束或 +prefix cache eviction 时的开销。``KVPool`` 在条件满足时会调用 ``mllm-kernel`` 的 +``store_cache`` JIT kernel 写入 K/V;否则回退到 PyTorch indexing。 + +Prefix cache 当前有三种实现: + +- ``RadixCache``:标准 radix-tree prefix cache。 +- ``ChunkCache``:关闭 radix cache 时使用的简单缓存路径。 +- ``MambaRadixCache``:为包含 GDN / Mamba-like 状态的 hybrid 模型预留的状态缓存路径。 + +当启用 ``RadixCache`` 时,extend batch 会先执行 prefix matching。命中的 prefix token 不再 +重复计算,但对应 radix tree 节点会被 lock,直到请求结束或资源释放时再 unlock。 + +IPC 与多模态数据传输 +---------------------------------------- + +普通控制消息通过 ZMQ 传输。多模态请求中的大 tensor 可以走 shared queue fast path, +由 ``enable_shared_queue`` 和 ``tensor_transport_mode`` 控制。 + +``tensor_transport_mode`` 支持三种模式: + +.. list-table:: + :header-rows: 1 + + * - 模式 + - 行为 + - 适用场景 + * - ``default`` + - GPU tensor 先拷到 CPU,再放入 POSIX shared memory。 + - 最稳妥,调试优先。 + * - ``cuda_ipc`` + - GPU tensor 通过 CUDA IPC handle 跨进程共享。 + - 避免 GPU->CPU 拷贝,但长服务中可能有 PyTorch IPC 生命周期问题。 + * - ``cuda_ipc_pool`` + - 使用预分配 GPU workspace,发送方回收 chunk。 + - 面向生产服务的推荐 GPU tensor 传输方式。 + +与 mllm C++ Backend 的关系 +---------------------------------------- + +``pymllm`` 和 ``cpu_backend``、``qnn_backend``、``ascend_backend`` 的层级不同: + +- C++ Backend 接入的是 mllm C++ 的 Tensor、Op、Module、Dispatcher 和设备 allocator。 +- ``pymllm`` 接入的是 Python/PyTorch serving pipeline,主要服务于在线推理、模型加载、 + KV cache、调度和 CUDA kernel 集成。 +- ``mllm-kernel`` 是两者可以共享思想的低层 kernel 工具包,但当前 ``pymllm`` 更直接依赖 + 其中的 Python JIT CUDA kernel。 diff --git a/docs/pymllm_runtime/setup_and_usage.rst b/docs/pymllm_runtime/setup_and_usage.rst new file mode 100644 index 000000000..3097bbbbf --- /dev/null +++ b/docs/pymllm_runtime/setup_and_usage.rst @@ -0,0 +1,359 @@ +pymllm Setup and Usage +====================== + +总览 +---------------------------------------- + +``pymllm`` 是 mllm 面向 Python 生态的推理服务运行时,主要面向 NVIDIA Jetson +Orin 系列边缘 GPU 设备,例如 Jetson Orin NX 与 Jetson AGX Orin。它覆盖 +Qwen3 / Qwen3-VL 的 BF16、W4A16 和 W8A8 推理路径,并提供 OpenAI-compatible +HTTP API。 + +环境要求 +---------------------------------------- + +当前推荐基于 `jetson-containers `_ +提供的 Jetson PyTorch/CUDA 基础镜像进行开发。这样可以避免在 Jetson 上手工处理 +PyTorch、CUDA、cuDNN、Python ABI 等基础依赖。 + +已验证环境如下: + +.. list-table:: + :header-rows: 1 + + * - 组件 + - 版本或说明 + * - JetPack / Jetson Linux + - JetPack ``6.2.1`` / Jetson Linux ``36.4.4`` (L4T ``R36.4.4``) + * - Python + - ``3.10.12`` + * - PyTorch + - ``2.4.0`` + * - torchvision + - ``0.19.0a0+48b1edf`` + * - transformers + - ``5.3.0`` + * - safetensors + - ``0.7.0`` + * - flashinfer + - ``0.6.7`` + * - Triton Language + - ``triton==3.6.0`` aarch64 wheel + * - CUDA + - ``12.6`` + * - GPU + - Jetson Orin NX,SM87 + +安装依赖 +---------------------------------------- + +在 Jetson 容器中克隆仓库后,进入仓库根目录安装 ``pymllm`` 和 ``mllm-kernel``: + +.. code-block:: bash + + cd + SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e . + python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation + +``transformers`` 可按项目需要自行安装。``triton`` 和 ``flashinfer`` 可以从 +Jetson AI Lab 的 wheel 源安装,也可以从官方 PyPI 或对应上游项目安装: + +.. 
code-block:: bash + + # 方式一:从 Jetson AI Lab 安装 Jetson wheel。 + python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ triton flashinfer + + # 方式二:从官方 PyPI 固定 Triton,再单独安装 FlashInfer。 + python3 -m pip install --index-url https://pypi.org/simple triton==3.6.0 + python3 -m pip install --extra-index-url https://pypi.jetson-ai-lab.io/ flashinfer + +在 Jetson / aarch64 上,Triton wheel 的可用性会受到 wheel 来源、CUDA 路径和 +``ptxas`` / ``cuda.h`` 查找路径影响。Jetson AI Lab 源提供面向 JetPack 6 / +CUDA 12.6 的 Triton wheel;在已验证环境中,官方 PyPI 的 ``triton==3.6.0`` +manylinux aarch64 wheel 更接近开箱即用。若使用 Jetson AI Lab wheel 遇到 +``ptxas`` 或 CUDA 头文件查找问题,可显式设置 ``TRITON_PTXAS_PATH`` 和 +``CPATH`` 后重试。无论选择哪个来源,都建议用最小 Triton kernel 或 +``per_token_quant_int8`` 做 smoke test。 + +最小导入检查: + +.. code-block:: bash + + python3 - <<'PY' + import pymllm + import mllm_kernel + + print("pymllm import ok") + print("mllm_kernel import ok") + PY + +CUTLASS 头文件 +---------------------------------------- + +W8A8 的高性能 GEMM 路径依赖 CUTLASS 头文件。当前查找顺序为: + +1. ``CUTLASS_HOME/include`` +2. ``flashinfer`` 内置的 ``data/cutlass/include`` +3. ``/usr/local/include``、``/usr/include``、``/usr/local/cuda/include`` + +首次调用 CUTLASS W8A8 kernel 会触发 JIT 编译,编译产物会复用: + +.. code-block:: text + + ~/.cache/mllm_kernel/cutlass_int8_scaled_mm/ + +如果需要重新验证首次编译行为,可以删除该目录后再次运行。 + +启动服务 +---------------------------------------- + +``pymllm`` 的服务入口是 ``pymllm.server.launch``。服务启动后会提供 +``/health``、``/v1/models``、``/v1/completions``、``/v1/chat/completions``、 +``/generate`` 等接口。 + +W4A16 / W8A8 量化模型 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``compressed-tensors`` 量化模型使用同一个启动入口。运行时会根据模型 +``config.json`` 中的量化配置识别 W4A16 或 W8A8 路径。 + +.. code-block:: bash + + cd + + python3 -m pymllm.server.launch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype float16 \ + --quantization.method compressed-tensors \ + --server.host 0.0.0.0 \ + --server.port 30000 \ + --server.attention_backend auto \ + --server.gdn_decode_backend pytorch \ + --server.mem_fraction_static 0.05 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 256 \ + --server.max_prefill_tokens 128 \ + --server.disable_radix_cache \ + --server.disable_cuda_graph \ + --server.log_level debug + +BF16 原生模型 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +BF16 或 FP16 原生模型不需要设置 ``--quantization.method``: + +.. code-block:: bash + + cd + + python3 -m pymllm.server.launch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype bfloat16 \ + --server.host 0.0.0.0 \ + --server.port 30000 \ + --server.attention_backend auto \ + --server.mem_fraction_static 0.05 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 256 \ + --server.max_prefill_tokens 128 \ + --server.disable_radix_cache \ + --server.log_level info + +常用参数 +---------------------------------------- + +.. 
list-table:: + :header-rows: 1 + + * - 参数 + - 说明 + * - ``--server.model_path`` + - 模型权重目录,通常是 HuggingFace 或 ModelScope 格式。 + * - ``--server.tokenizer_path`` + - tokenizer 目录;不设置时默认等于 ``model_path``。 + * - ``--server.dtype`` + - 模型运行 dtype,可选 ``auto``、``float16``、``bfloat16``、``float32``。 + * - ``--quantization.method compressed-tensors`` + - 启用 ``compressed-tensors`` 权重加载与线性层执行路径。 + * - ``--server.max_running_requests`` + - 同时运行的请求数。Jetson 小显存环境下通常从 ``1`` 开始调试。 + * - ``--server.max_total_tokens`` + - KV cache token pool 的总容量上限。 + * - ``--server.max_prefill_tokens`` + - 单轮 prefill 可处理的 token 上限。 + * - ``--server.disable_radix_cache`` + - 关闭 Radix Cache,改用 ``ChunkCache``。 + * - ``--server.disable_cuda_graph`` + - 关闭 decode CUDA Graph,便于调试动态路径。 + +OpenAI-compatible 请求 +---------------------------------------- + +健康检查: + +.. code-block:: bash + + curl -s --noproxy '*' http://127.0.0.1:30000/v1/models ; echo + +文本请求: + +.. code-block:: bash + + curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "你好,只回复:ok"}], + "max_tokens": 8, + "temperature": 0.0, + "stream": false + }' ; echo + +图文请求中,图片路径需要是容器内可访问的绝对路径,不要带 ``file://`` 前缀: + +.. code-block:: bash + + cat > /tmp/mm_req_path.json <<'JSON' + { + "model": "default", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "请描述这张图片。"}, + {"type": "image_url", "image_url": {"url": "/workspace/test.png"}} + ] + } + ], + "max_tokens": 128, + "temperature": 0.0, + "stream": false + } + JSON + + curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + --data @/tmp/mm_req_path.json ; echo + +开发与测试 +---------------------------------------- + +常用单元测试: + +.. code-block:: bash + + pytest pymllm/tests/test_compressed_tensors_config.py -q + pytest pymllm/tests/test_compressed_tensors_runtime.py -q + pytest pymllm/tests/test_qwen3_model_registry.py -q + pytest pymllm/tests/test_qwen3_weight_loading.py -q + pytest pymllm/tests/test_qwen3_forward_timing.py -q + pytest mllm-kernel/tests/test_int8_scaled_mm_cutlass.py -q + +模型级 benchmark: +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``bench_one_batch`` 是对齐 SGLang 口径的低层离线 benchmark。它直接初始化 +``pymllm.executor.model_runner.ModelRunner``,绕过 HTTP server、tokenizer 进程、 +scheduler 进程和 detokenizer 进程,用 synthetic text-only token ids 测一次静态 +prefill,再测逐 token decode。该工具适合分析模型 forward、KV cache、attention、 +CUDA Graph 与量化 kernel 的模型级开销,不代表在线服务的 TTFT / ITL / E2E 指标。 + +典型用法: + +.. code-block:: bash + + PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype float16 \ + --quantization.method compressed-tensors \ + --server.mem_fraction_static 0.1 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 2048 \ + --server.disable_radix_cache \ + --server.log_level info \ + --run-name qwen3vl_w8a8_bench_one_batch \ + --batch-size 1 \ + --input-len 256 512 1024 \ + --output-len 128 \ + --result-filename /tmp/pymllm_bench_one_batch.jsonl + +其中 ``--batch-size``、``--input-len`` 和 ``--output-len`` 都支持多个值,脚本会遍历 +所有组合并向 JSONL 文件追加结果。``output_len`` 采用 SGLang 的总输出 token 语义: +prefill 后已得到第一个 next token,后续 decode loop 执行 ``output_len - 1`` 步。 + +执行结构: + +.. 
code-block:: text + + pymllm.bench_one_batch CLI + | + |-- parse GlobalConfig args and BenchArgs + |-- load HuggingFace AutoConfig into cfg.model.hf_config + | + |-- ModelRunner.initialize() + | |-- load model and quantization config + | |-- initialize KV pools and attention backend + | |-- optionally capture decode CUDA Graph + | + |-- warmup once + | + |-- for each (batch_size, input_len, output_len): + | + |-- clear req_to_token_pool and token_to_kv_pool_allocator + |-- build synthetic input_ids + |-- prefill: + | allocate request slots and KV slots + | write prompt KV mapping + | prepare ForwardBatch(EXTEND) + | synchronize, run forward + sampling, synchronize + | + |-- decode loop: + allocate one KV slot per request + write current token mapping + prepare ForwardBatch(DECODE) + synchronize, run forward + sampling, synchronize + update seq_lens and next token ids + | + |-- append JSONL result rows + +Profile 辅助入口: +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``bench_one_batch`` 保留了基于 ``torch.profiler`` 的 profile 参数,主要用于本地 +kernel timeline 分析。当前公开 benchmark 记录没有使用 profile 结果,因此它不作为标准 +性能数据口径的一部分。使用前建议先用较小的 ``input_len`` / ``output_len`` 做一次 +trace 生成验证,再扩大到正式 case。 + +.. code-block:: bash + + PYMLLM_TORCH_PROFILER_DIR=/tmp \ + PYTHONPATH="$PWD:$PWD/mllm-kernel" python3 -m pymllm.bench_one_batch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype bfloat16 \ + --server.mem_fraction_static 0.1 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 2048 \ + --server.log_level info \ + --batch-size 1 \ + --input-len 256 \ + --output-len 128 \ + --profile \ + --profile-stage decode \ + --profile-steps 1 + +已知限制 +---------------------------------------- + +- W8A8 CUTLASS 当前通过 JIT 编译,首次启动有明显编译开销。 +- W8A8 激活量化使用 Triton kernel;decode 小 batch 下固定量化开销仍是后续优化点。 +- Qwen3-VL 的 ViT、``lm_head``、embedding 和 LayerNorm 不在当前 W8A8 量化范围内。 +- 当前文档中的 Jetson 性能与稳定性结论主要来自 Orin NX / SM87,需要在其他 GPU 上重新验证。 +- OpenAI-compatible API 的服务级指标和 ``bench_one_batch`` 的模型级指标口径不同,不应直接混用。 diff --git a/mllm-kernel/benchmarks/bench_int8_scaled_mm.py b/mllm-kernel/benchmarks/bench_int8_scaled_mm.py new file mode 100644 index 000000000..441494616 --- /dev/null +++ b/mllm-kernel/benchmarks/bench_int8_scaled_mm.py @@ -0,0 +1,151 @@ +"""Benchmark int8_scaled_mm implementations. + +Covers torch._int_mm and the CUTLASS W8A8 kernel. 
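+
+Both backends are expected to produce approximately
+    (mat_a @ mat_b) * scales_a[:, None] * scales_b[None, :]
+accumulated in int32/float32 and cast to out_dtype, with an optional bias add.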
+ +Usage: + python benchmarks/bench_int8_scaled_mm.py +""" +from __future__ import annotations + +import time +from typing import Callable, Optional + +import torch + + +# --------------------------------------------------------------------------- +# Reference / backend implementations +# --------------------------------------------------------------------------- + +def _torch_int_mm_scaled( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + scales_a: torch.Tensor, + scales_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """torch._int_mm + scale dequant reference backend.""" + m = mat_a.shape[0] + if m <= 16: + padded = torch.zeros((17, mat_a.shape[1]), device=mat_a.device, dtype=torch.int8) + padded[:m].copy_(mat_a) + out_i32 = torch._int_mm(padded, mat_b)[:m] + else: + out_i32 = torch._int_mm(mat_a, mat_b) + out = out_i32.to(torch.float32) + out.mul_(scales_a.view(-1, 1)) + out.mul_(scales_b.view(1, -1)) + out = out.to(out_dtype) + if bias is not None: + out.add_(bias) + return out + + +def _try_load_cutlass_kernel(): + try: + from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import int8_scaled_mm + return int8_scaled_mm + except Exception: + return None + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + +def bench_fn( + fn: Callable, + args: tuple, + kwargs: dict, + warmup: int = 5, + repeat: int = 20, +) -> float: + """Returns median latency in ms.""" + for _ in range(warmup): + fn(*args, **kwargs) + torch.cuda.synchronize() + + times = [] + for _ in range(repeat): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn(*args, **kwargs) + torch.cuda.synchronize() + t1 = time.perf_counter() + times.append((t1 - t0) * 1e3) + times.sort() + return times[len(times) // 2] + + +def run_benchmarks(): + device = "cuda" + out_dtype = torch.float16 + + # Shapes representative of Qwen3-VL-2B linear layers + shapes = [ + # (M, K, N) — M=seq_len, K=in_features, N=out_features + (1, 2048, 2048), # decode, hidden->hidden + (1, 2048, 6144), # decode, hidden->3*hidden (QKV) + (8, 2048, 6144), # small batch + (16, 2048, 6144), # boundary (torch._int_mm M<=16 padding) + (32, 2048, 6144), # medium batch + (93, 2048, 6144), # typical prefill + (128, 2048, 6144), # larger prefill + (93, 6144, 2048), # prefill, wide->narrow (down_proj) + ] + + backends = {} + + # Backend: torch._int_mm + backends["torch._int_mm"] = _torch_int_mm_scaled + + # Backend: CUTLASS + cutlass_fn = _try_load_cutlass_kernel() + if cutlass_fn is not None: + backends["cutlass"] = cutlass_fn + + print(f"{'Shape':>20s}", end="") + for name in backends: + print(f" {name:>16s}", end="") + print() + print("-" * (20 + 18 * len(backends))) + + results = [] + for M, K, N in shapes: + torch.manual_seed(42) + mat_a = torch.randint(-127, 128, (M, K), dtype=torch.int8, device=device) + mat_b = torch.randint(-127, 128, (K, N), dtype=torch.int8, device=device) + scales_a = torch.rand(M, dtype=torch.float32, device=device) + 0.01 + scales_b = torch.rand(N, dtype=torch.float32, device=device) + 0.01 + + # CUTLASS needs col-major B + mat_b_colmaj = mat_b.t().contiguous().t() + + row = {"shape": f"({M},{K},{N})"} + print(f"{row['shape']:>20s}", end="") + + for name, fn in backends.items(): + kwargs = dict(out_dtype=out_dtype) + b_arg = mat_b_colmaj if name == "cutlass" else mat_b + try: + ms = bench_fn(fn, (mat_a, b_arg, scales_a, scales_b), kwargs) + row[name] = 
f"{ms:.3f}" + print(f" {ms:>13.3f} ms", end="") + except Exception as e: + row[name] = f"ERR: {e}" + print(f" {'ERROR':>16s}", end="") + + print() + results.append(row) + + return results + + +if __name__ == "__main__": + print("=" * 60) + print("INT8 Scaled MM Benchmark") + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"SM: {torch.cuda.get_device_capability(0)}") + print("=" * 60) + run_benchmarks() diff --git a/mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py b/mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py new file mode 100644 index 000000000..534b1d16a --- /dev/null +++ b/mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py @@ -0,0 +1,181 @@ +"""Kernel-level benchmark: W4A16 (GPTQ-Marlin) vs W8A8 (Triton quant + CUTLASS GEMM). + +Isolates kernel performance from serving framework overhead. +Shapes are from Qwen3-VL-2B linear layers. + +Usage: + cd /workspace/.worktrees/pymllm-qwen3-vl-w8a8 + python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py +""" +from __future__ import annotations + +import time +from typing import Callable + +import torch + + +# --------------------------------------------------------------------------- +# Benchmark utility +# --------------------------------------------------------------------------- + +def bench(fn: Callable, warmup: int = 5, repeat: int = 20) -> float: + """Returns median latency in ms.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + times = [] + for _ in range(repeat): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + torch.cuda.synchronize() + times.append((time.perf_counter() - t0) * 1e3) + times.sort() + return times[len(times) // 2] + + +# --------------------------------------------------------------------------- +# W8A8 kernel loaders +# --------------------------------------------------------------------------- + +def load_cutlass_mm(): + from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import int8_scaled_mm + return int8_scaled_mm + + +def load_triton_quant(): + from pymllm.quantization.kernels.int8_activation_triton import per_token_quant_int8 + return per_token_quant_int8 + + +# --------------------------------------------------------------------------- +# W4A16 kernel loader +# --------------------------------------------------------------------------- + +def load_marlin(): + from mllm_kernel.cuda.jit import gptq_marlin_gemm, gptq_marlin_repack + from pymllm.quantization.methods.compressed_tensors import ( + marlin_make_workspace, + marlin_make_empty_g_idx, + marlin_permute_scales, + SCALAR_TYPE_UINT4B8, + ) + return gptq_marlin_gemm, gptq_marlin_repack, marlin_make_workspace, \ + marlin_make_empty_g_idx, marlin_permute_scales, SCALAR_TYPE_UINT4B8 + + +def prepare_marlin_weights(K: int, N: int, group_size: int, device: str): + """Create fake W4A16 weights in Marlin format for benchmarking.""" + gptq_marlin_gemm, gptq_marlin_repack, marlin_make_workspace, \ + marlin_make_empty_g_idx, marlin_permute_scales, SCALAR_TYPE_UINT4B8 = load_marlin() + + pack_factor = 8 # 32 / 4 bits + w_packed = torch.randint( + 0, 2**31, (N, K // pack_factor), dtype=torch.int32, device=device, + ) + w_scale = ( + torch.rand(N, K // group_size, dtype=torch.float16, device=device) + 0.01 + ) + + repacked = gptq_marlin_repack( + w_packed.t().contiguous(), + perm=torch.empty(0, dtype=torch.int32, device=device), + size_k=K, size_n=N, num_bits=4, + ) + scales_perm = marlin_permute_scales( + w_scale.t().contiguous(), size_k=K, size_n=N, group_size=group_size, + ) + workspace = marlin_make_workspace(torch.device(device)) + g_idx = 
marlin_make_empty_g_idx(torch.device(device)) + + return repacked, scales_perm, workspace, g_idx, SCALAR_TYPE_UINT4B8 + + +# --------------------------------------------------------------------------- +# Main benchmark +# --------------------------------------------------------------------------- + +def run_benchmarks(): + device = "cuda" + group_size = 32 + + shapes = [ + # (M, K, N, description) + (1, 2048, 6144, "QKV proj"), + (1, 2048, 2048, "O proj"), + (1, 6144, 2048, "down proj"), + (93, 2048, 6144, "QKV proj"), + (93, 2048, 2048, "O proj"), + (93, 6144, 2048, "down proj"), + (128, 2048, 6144, "QKV proj"), + ] + + # Load kernels + cutlass_mm = load_cutlass_mm() + triton_quant = load_triton_quant() + gptq_marlin_gemm = load_marlin()[0] + + # Header + print(f"{'Shape':<22s} {'':>6s} {'W4A16':>8s} {'W8A8':>8s} {'W8A8':>8s} {'W8A8':>8s}") + print(f"{'(M, K, N)':<22s} {'desc':>6s} {'Marlin':>8s} {'quant':>8s} {'GEMM':>8s} {'total':>8s}") + print("-" * 72) + + for M, K, N, desc in shapes: + torch.manual_seed(42) + + # ----- W8A8 setup ----- + x_fp16 = torch.randn(M, K, device=device, dtype=torch.float16) + w_int8_col = torch.randint( + -127, 128, (N, K), dtype=torch.int8, device=device, + ).t() # (K, N) col-major, stride(0)==1 + w_scale_f32 = torch.rand(N, dtype=torch.float32, device=device) * 0.01 + + # Pre-quantize for GEMM-only bench + x_q, x_s = triton_quant(x_fp16) + + ms_quant = bench(lambda: triton_quant(x_fp16)) + ms_gemm = bench(lambda: cutlass_mm(x_q, w_int8_col, x_s, w_scale_f32, torch.float16)) + ms_w8a8 = ms_quant + ms_gemm + + # ----- W4A16 setup ----- + repacked, scales_perm, workspace, g_idx, scalar_type = \ + prepare_marlin_weights(K, N, group_size, device) + x_marlin = torch.randn(M, K, device=device, dtype=torch.float16) + + def run_marlin(): + return gptq_marlin_gemm( + a=x_marlin, c=None, b_q_weight=repacked, b_scales=scales_perm, + global_scale=None, b_zeros=g_idx, g_idx=g_idx, perm=g_idx, + workspace=workspace, b_q_type_id=scalar_type.id, + size_m=M, size_n=N, size_k=K, is_k_full=True, + use_fp32_reduce=True, is_zp_float=False, + ) + + ms_marlin = bench(run_marlin) + + # ----- Print ----- + tag = "decode" if M <= 8 else "prefill" + print( + f" ({M:>3},{K:>4},{N:>4}) {desc:<8s}" + f" {ms_marlin:>7.3f} {ms_quant:>7.3f} {ms_gemm:>7.3f} {ms_w8a8:>7.3f}" + ) + + # Summary + print() + print("W4A16 Marlin : gptq_marlin_gemm (int4 weight * fp16 activation, 1 kernel)") + print("W8A8 quant : Triton per_token_quant_int8 (fp16 -> int8, 1 kernel)") + print("W8A8 GEMM : CUTLASS int8_scaled_mm (int8 * int8, fused scale, 1 kernel)") + print("W8A8 total : quant + GEMM (2 kernel launches)") + print() + print("Key insight: W8A8 GEMM alone is faster than W4A16 Marlin,") + print("but activation quantization overhead makes W8A8 total slower at decode (M=1).") + + +if __name__ == "__main__": + print("=" * 72) + print("W4A16 vs W8A8 Kernel Benchmark") + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"SM: {torch.cuda.get_device_capability(0)}") + print("=" * 72) + run_benchmarks() diff --git a/mllm-kernel/cmake/CPM.cmake b/mllm-kernel/cmake/CPM.cmake index 3bfca27ba..ce36c9400 100644 --- a/mllm-kernel/cmake/CPM.cmake +++ b/mllm-kernel/cmake/CPM.cmake @@ -1,29 +1,30 @@ # SPDX-License-Identifier: MIT -# Download CPM.cmake on-the-fly -# This is a lightweight bootstrap that downloads the actual CPM.cmake +# Prefer the vendored CPM.cmake from the parent mllm repo. This avoids relying +# on network access for editable builds while keeping standalone fallback logic. 
set(CPM_VERSION 0.42.0) set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_VERSION}.cmake") +set(PARENT_CPM "${CMAKE_CURRENT_LIST_DIR}/../../cmake/CPM.cmake") -if(NOT EXISTS ${CPM_DOWNLOAD_LOCATION}) - message(STATUS "Downloading CPM.cmake v${CPM_VERSION}...") - file(DOWNLOAD - https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_VERSION}/CPM.cmake - ${CPM_DOWNLOAD_LOCATION} - STATUS download_status - ) - list(GET download_status 0 download_status_code) - if(NOT download_status_code EQUAL 0) - # Fallback: copy from parent mllm project if available - set(PARENT_CPM "${CMAKE_CURRENT_SOURCE_DIR}/../cmake/CPM.cmake") - if(EXISTS ${PARENT_CPM}) - message(STATUS "Using CPM.cmake from parent project") - file(COPY ${PARENT_CPM} DESTINATION "${CMAKE_BINARY_DIR}/cmake/") - file(RENAME "${CMAKE_BINARY_DIR}/cmake/CPM.cmake" ${CPM_DOWNLOAD_LOCATION}) - else() +if(EXISTS "${PARENT_CPM}") + include("${PARENT_CPM}") +else() + if(NOT EXISTS "${CPM_DOWNLOAD_LOCATION}") + message(STATUS "Downloading CPM.cmake v${CPM_VERSION}...") + file(DOWNLOAD + https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_VERSION}/CPM.cmake + "${CPM_DOWNLOAD_LOCATION}" + STATUS download_status + ) + list(GET download_status 0 download_status_code) + if(NOT download_status_code EQUAL 0) message(FATAL_ERROR "Failed to download CPM.cmake") endif() endif() + + include("${CPM_DOWNLOAD_LOCATION}") endif() -include(${CPM_DOWNLOAD_LOCATION}) +if(NOT COMMAND CPMAddPackage) + message(FATAL_ERROR "CPM.cmake loaded, but CPMAddPackage is not available") +endif() diff --git a/mllm-kernel/include/mllm_kernel/scalar_type.hpp b/mllm-kernel/include/mllm_kernel/scalar_type.hpp index def41a12b..bec1c46db 100644 --- a/mllm-kernel/include/mllm_kernel/scalar_type.hpp +++ b/mllm-kernel/include/mllm_kernel/scalar_type.hpp @@ -6,7 +6,7 @@ #include #endif -namespace host { +namespace mllm_kernel::host { // // ScalarType can represent a wide range of floating point and integer types, @@ -257,4 +257,6 @@ static inline constexpr auto kFloat16 = kHalf; static inline constexpr auto kBFloat16 = kFE8M7; static inline constexpr auto kFloat16Id = kFloat16.id(); -} // namespace host +} // namespace mllm_kernel::host + +namespace host = ::mllm_kernel::host; diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h new file mode 100644 index 000000000..9f85bee28 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -0,0 +1,309 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h + +#pragma once + +#include +#include + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template < + typename ThreadblockShape_, + int ThreadCount, + typename ScaleTileIterator_, + typename OutputTileIterator_, + typename ElementAccumulator_, + typename ElementCompute_, + typename ElementwiseFunctor_, + bool UseMasking_ = false> +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments( + typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_alpha_, + int64_t batch_stride_C_, + int64_t batch_stride_D_) + : elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; + + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) {} + }; + + /// Shared storage + struct SharedStorage {}; + + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const with_bias_; + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment 
fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol( + Params const& params, + SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + bool with_bias, + bool per_token_quant, + bool per_channel_quant, + AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + with_bias_(with_bias), + per_token_quant_(per_token_quant), + per_channel_quant_(per_channel_quant), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), + iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) { + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); + } + + if (with_bias_) { + iterator_C_.load(fragment_C_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) { + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = 
reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + if (with_bias_) { + NumericArrayConverter bias_converter; + OutputVector bias = reinterpret_cast(&fragment_C_)[column_idx]; + result = bias_accumulator_(result, bias_converter(bias)); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment bias_accumulator_(ComputeFragment const& accum, ComputeFragment const& bias) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputVector::kElements; ++i) { + result[i] = accum[i] + bias[i]; + } + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h new file mode 100644 index 000000000..b58d84318 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -0,0 +1,356 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h +#pragma once + +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST( + " grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + return -1; + } + + smem_capacity = 
static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST( + "GemmUniversalBaseCompat::initialize() - workspace " << workspace + << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h new file mode 100644 index 000000000..905d11ba2 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -0,0 +1,492 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h + +#pragma once + +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function + > +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} + + /// constructs an arguments structure + Arguments( + GemmCoord problem_size_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, + TensorRefC ref_C_, + TensorRefC ref_D_, + typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(GemmUniversalMode::kGemm), + problem_size(problem_size_), + batch_count(1), + ref_A(ref_A_), + ref_B(ref_B_), + ref_alpha_col(ref_alpha_col_), + ref_alpha_row(ref_alpha_row_), + ref_C(ref_C_), + ref_D(ref_D_), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_D(0), + epilogue_visitor(epilogue_visitor_) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params 
params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), + params_A(0), + params_B(0), + params_alpha_col(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_alpha_col(nullptr), + ptr_alpha_row(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0) {} + + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_alpha_col(args.ref_alpha_col.layout()), + params_alpha_row(args.ref_alpha_col.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + ptr_alpha_col(args.ref_alpha_col.data()), + ptr_alpha_row(args.ref_alpha_row.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if 
(platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if ( + platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if ( + platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if ( + platform::is_same>::value || + platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. 
+ // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + bool with_bias = true; + if (params.ptr_C == nullptr) { + with_bias = false; + } + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_alpha_col, + params.params_C, + params.params_D, + with_bias, + true, + true, + params.ptr_alpha_row, + params.ptr_alpha_col, + params.ptr_C, + params.ptr_D, + threadblock_offset, + blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute 
the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + run_kernel(params, shared_storage); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu new file mode 100644 index 000000000..89cfeff4c --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/int8/int8_scaled_mm_cutlass.cu @@ -0,0 +1,404 @@ +/** + * CUTLASS INT8 Scaled MatMul for SM80+ (Ampere). + * + * Ported from sglang sgl-kernel/csrc/gemm/int8_gemm_kernel.cu + * Adapted for mllm-kernel with SM87 (Jetson Orin) support. + * + * Only includes CUTLASS 2.x paths (SM80/87/89). No SM90 (Hopper) support. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" +#include "cutlass_extensions/gemm/gemm_universal_base_compat.h" +#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h" + +// --------------------------------------------------------------------------- +// Utility +// --------------------------------------------------------------------------- + +inline int getSMVersion() { + int device{-1}; + cudaGetDevice(&device); + int sm_major = 0, sm_minor = 0; + cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device); + return sm_major * 10 + sm_minor; +} + +// --------------------------------------------------------------------------- +// Core CUTLASS GEMM template (CUTLASS 2.x with per-row/col scale epilogue) +// --------------------------------------------------------------------------- + +template < + typename ElementOutput, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + int NumStages> +void cutlass_int8_scaled_mm( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + + using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementInputA, ElementInputB, + ElementOutput, ElementCompute>; + using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< + ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, + ElementInputB, cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, + ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, + OperatorClass, 
ArchTag, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, NumStages, + true, typename DefaultGemmConf::Operator>::GemmKernel; + + using AlphaColTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIterator< + cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< + typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape, + typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count, + GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads, + GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, + cutlass::sizeof_bits::value>, + ElementCompute>; + + using EpilogueVisitor = + typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol< + ThreadblockShape, + GemmKernel_::kThreadCount, + AlphaColTileIterator, + typename GemmKernel_::Epilogue::OutputTileIterator, + ElementAccumulator, ElementCompute, EpilogueOutputOp>; + + using Epilogue = typename cutlass::epilogue::threadblock:: + EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue; + + using GemmKernel = cutlass::gemm::kernel::GemmWithEpilogueVisitor< + typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>; + + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + Gemm gemm_op; + + int m = mat_a.size(0); + int k = mat_a.size(1); + int n = mat_b.size(1); + + auto a_ptr = static_cast(mat_a.data_ptr()); + auto b_ptr = static_cast(mat_b.data_ptr()); + auto o_ptr = static_cast(out.data_ptr()); + auto a_s_ptr = static_cast(scales_a.data_ptr()); + auto b_s_ptr = static_cast(scales_b.data_ptr()); + + int64_t lda = mat_a.stride(0); + int64_t ldb = mat_b.stride(1); + int64_t ldd = out.stride(0); + + ElementOutput* bias_ptr = nullptr; + int64_t ldc = 0; + if (bias) { + bias_ptr = static_cast(bias->data_ptr()); + } + + typename EpilogueOutputOp::Params linearScalingParams; + typename EpilogueVisitor::Arguments visitor_args{linearScalingParams}; + + typename Gemm::Arguments args{ + {m, n, k}, + {a_ptr, lda}, {b_ptr, ldb}, + {b_s_ptr, 0}, {a_s_ptr, 0}, + {bias_ptr, ldc}, {o_ptr, ldd}, + visitor_args}; + + auto workspace = torch::empty( + gemm_op.get_workspace_size(args), + torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + + auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK( + can_implement == cutlass::Status::kSuccess, + "CUTLASS can_implement failed: ", + cutlassGetStatusString(can_implement)); + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK( + status == cutlass::Status::kSuccess, + "CUTLASS execution failed: ", + cutlassGetStatusString(status)); +} + +// --------------------------------------------------------------------------- +// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to: +// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh +// --------------------------------------------------------------------------- + +template +void sm89_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { + int m = mat_a.size(0); + int n = mat_b.size(1); + if (m <= 16) { + if (n <= 8192) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + 
cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, 4>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 32) { + if (n <= 8192) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, 4>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + if (n <= 8192) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 3>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + if (n <= 8192) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, 3>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else if (n <= 16384) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 256) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 3>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else if (n <= 8192) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else if (n <= 16384) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 3>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } +} + +// --------------------------------------------------------------------------- +// SM80 dispatch (160K shared memory, for SM80/SM87) +// --------------------------------------------------------------------------- + +template +void sm80_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { + int m = mat_a.size(0); + int n = mat_b.size(1); + if (m <= 16) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, 6>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 32) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, 6>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 128 && n < 
8192) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, 5>(out, mat_a, mat_b, scales_a, scales_b, bias); + } +} + +// --------------------------------------------------------------------------- +// Entry point +// --------------------------------------------------------------------------- + +torch::Tensor int8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const std::string& out_dtype_str, + const c10::optional& bias) { + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be 2D"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be 2D"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be row-major"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be column-major"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "shape mismatch"); + TORCH_CHECK(mat_a.size(1) % 16 == 0, "K must be multiple of 16"); + TORCH_CHECK(mat_b.size(1) % 8 == 0, "N must be multiple of 8"); + TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8"); + TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8"); + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "scales_a size mismatch"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "scales_b size mismatch"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be fp32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be fp32"); + + torch::Dtype out_dtype; + if (out_dtype_str == "float16") { + out_dtype = torch::kHalf; + } else if (out_dtype_str == "bfloat16") { + out_dtype = torch::kBFloat16; + } else { + TORCH_CHECK(false, "out_dtype must be 'float16' or 'bfloat16', got: ", out_dtype_str); + } + + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "bias size mismatch"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match out_dtype"); + } + + auto out = torch::empty( + {mat_a.size(0), mat_b.size(1)}, + mat_a.options().dtype(out_dtype)); + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using ArchTag = cutlass::arch::Sm80; + + // SM86/SM89 have smaller shared memory and use sglang's SM89 tile shapes. + // SM87 (Jetson Orin) has 164K smem, same as SM80, so it stays on SM80. + int sm_version = getSMVersion(); + + if (sm_version >= 80 && sm_version < 90) { + if (sm_version == 86 || sm_version == 89) { + if (out_dtype == torch::kBFloat16) { + sm89_dispatch_shape( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm89_dispatch_shape( + out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else { + if (out_dtype == torch::kBFloat16) { + sm80_dispatch_shape( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm80_dispatch_shape( + out, mat_a, mat_b, scales_a, scales_b, bias); + } + } + } else { + TORCH_CHECK(false, "Unsupported SM version: ", sm_version, ". 
Requires SM80-SM89."); + } + + return out; +} + +// --------------------------------------------------------------------------- +// PyBind11 binding +// --------------------------------------------------------------------------- + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("int8_scaled_mm", &int8_scaled_mm, + "CUTLASS INT8 scaled matmul with per-row/col scaling", + py::arg("mat_a"), py::arg("mat_b"), + py::arg("scales_a"), py::arg("scales_b"), + py::arg("out_dtype"), py::arg("bias") = py::none()); +} diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh index 483ff5fc5..5474d1b9f 100644 --- a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh @@ -1,13 +1,10 @@ #pragma once #include +#include #include -// Bridge the mllm_kernel::host namespace to the `host` namespace expected by -// Marlin code (originally from sglang). -namespace host = ::mllm_kernel::host; - namespace device::marlin { // Marlin params diff --git a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py index 1fe41f560..94d8b7144 100644 --- a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py +++ b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py @@ -2,11 +2,13 @@ from .awq_marlin_repack import awq_marlin_repack from .gdn_decode import gdn_decode from .gptq_marlin import gptq_marlin_gemm +from .gptq_marlin_repack import gptq_marlin_repack from .store_cache import can_use_store_cache, store_cache __all__ = [ "add_constant", "awq_marlin_repack", + "gptq_marlin_repack", "can_use_store_cache", "gdn_decode", "gptq_marlin_gemm", diff --git a/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py b/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py index 9eeefa765..1b33842ca 100644 --- a/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py +++ b/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py @@ -29,13 +29,14 @@ @cache_once def _make_gptq_marlin_gemm_kernel(dtype: torch.dtype): """JIT-compile the GPTQ Marlin GEMM kernel for a specific dtype.""" - args = make_cpp_args(dtype) + cpp_args = make_cpp_args(dtype) @jit( - args=args, + args=[dtype], device="cuda", cuda_files=["gemm/marlin/gptq_marlin.cuh"], - cuda_wrappers=[("gptq_marlin_gemm", f"gptq_marlin_gemm<{args}>")], + cpp_wrappers=[], + cuda_wrappers=[("gptq_marlin_gemm", f"gptq_marlin_gemm<{cpp_args}>")], func_name="gptq_marlin_gemm", ) def _kernel( diff --git a/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py b/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py new file mode 100644 index 000000000..5869b7eb3 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py @@ -0,0 +1,75 @@ +"""GPTQ/Compressed-Tensors Marlin repack CUDA JIT kernel.""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from mllm_kernel.jit_utils import cache_once, jit + + +def _normalize_perm( + perm: Optional[torch.Tensor], size_k: int, device: torch.device +) -> torch.Tensor: + if perm is None or perm.numel() == 0: + return torch.empty(0, dtype=torch.int32, device=device) + if perm.device != device: + raise ValueError("perm must live on the same device as b_q_weight") + if perm.dtype != torch.int32: + raise ValueError("perm must be int32") + if perm.numel() != size_k: + raise ValueError("perm length must equal size_k") + if torch.any(perm < 0) or torch.any(perm >= size_k): + raise ValueError("perm values must be in [0, size_k)") + return perm.contiguous() + + +@cache_once +def 
_make_gptq_marlin_repack_kernel(): + """JIT-compile the GPTQ repack kernel.""" + + @jit( + args=[], + device="cuda", + cuda_files=["gemm/marlin/gptq_marlin_repack.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("gptq_marlin_repack", "gptq_marlin_repack")], + func_name="gptq_marlin_repack", + ) + def _kernel( + compiled_module, + b_q_weight: torch.Tensor, + perm: torch.Tensor, + out: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + ) -> None: + compiled_module.gptq_marlin_repack( + b_q_weight, perm, out, size_k, size_n, num_bits + ) + + return _kernel + + +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: Optional[torch.Tensor], + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + """Repack GPTQ/Compressed-Tensors weights into Marlin layout.""" + + pack_factor = 32 // num_bits + tile_size = 16 + out = torch.empty( + (size_k // tile_size, size_n * tile_size // pack_factor), + dtype=b_q_weight.dtype, + device=b_q_weight.device, + ) + kernel = _make_gptq_marlin_repack_kernel() + perm_t = _normalize_perm(perm, size_k, b_q_weight.device) + kernel(b_q_weight, perm_t, out, size_k, size_n, num_bits) + return out diff --git a/mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py b/mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py new file mode 100644 index 000000000..4a24be515 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/int8_scaled_mm_cutlass.py @@ -0,0 +1,142 @@ +"""CUTLASS-based INT8 scaled matmul for SM80+ (Ampere). + +JIT-compiled via torch.utils.cpp_extension.load on first use. +Compiled module is cached per GPU arch at +~/.cache/mllm_kernel/cutlass_int8_scaled_mm/sm_XX/. +""" +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional + +import torch + +_module = None +_module_arch = None +_CSRC_DIR = Path(__file__).resolve().parent.parent / "csrc" +_CUTLASS_INC = None + + +def _find_cutlass_include() -> str: + """Find CUTLASS include path.""" + # Check environment variable + env_path = os.environ.get("CUTLASS_HOME") + if env_path and os.path.isdir(os.path.join(env_path, "include", "cutlass")): + return os.path.join(env_path, "include") + + # Check flashinfer bundled copy + try: + import flashinfer + fi_path = os.path.join( + os.path.dirname(flashinfer.__file__), + "data", "cutlass", "include", + ) + if os.path.isdir(os.path.join(fi_path, "cutlass")): + return fi_path + except ImportError: + pass + + # Check common system paths + for p in [ + "/usr/local/include", + "/usr/include", + "/usr/local/cuda/include", + ]: + if os.path.isdir(os.path.join(p, "cutlass")): + return p + + raise RuntimeError( + "CUTLASS include directory not found. Set CUTLASS_HOME or install " + "flashinfer (which bundles CUTLASS headers)." 
+ ) + + +def _current_cuda_arch() -> str: + major, minor = torch.cuda.get_device_capability() + arch = f"sm_{major}{minor}" + if major != 8: + raise RuntimeError( + f"CUTLASS int8_scaled_mm supports SM80-SM89, got {arch}" + ) + return arch + + +def _load_module(): + global _module, _module_arch, _CUTLASS_INC + + cuda_arch = _current_cuda_arch() + if _module is not None and _module_arch == cuda_arch: + return _module + + from torch.utils.cpp_extension import load + + _CUTLASS_INC = _find_cutlass_include() + + cache_dir = os.path.expanduser( + os.path.join("~/.cache/mllm_kernel/cutlass_int8_scaled_mm", cuda_arch) + ) + os.makedirs(cache_dir, exist_ok=True) + + source = str(_CSRC_DIR / "gemm" / "int8" / "int8_scaled_mm_cutlass.cu") + + _module = load( + name=f"mllm_cutlass_int8_scaled_mm_{cuda_arch}", + sources=[source], + extra_include_paths=[ + _CUTLASS_INC, + str(_CSRC_DIR), + ], + extra_cuda_cflags=[ + f"-arch={cuda_arch}", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "--expt-relaxed-constexpr", + "-std=c++17", + "-diag-suppress=20013", + "-diag-suppress=20015", + "-O3", + ], + build_directory=cache_dir, + verbose=False, + ) + _module_arch = cuda_arch + return _module + + +def int8_scaled_mm( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + scales_a: torch.Tensor, + scales_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """CUTLASS INT8 scaled matmul: out = (mat_a @ mat_b) * scales_a * scales_b + bias. + + Args: + mat_a: [M, K] int8, row-major (contiguous) + mat_b: [K, N] int8, column-major (stride(0)==1) + scales_a: [M] float32, per-row scale for activations + scales_b: [N] float32, per-column scale for weights + out_dtype: torch.float16 or torch.bfloat16 + bias: optional [N] tensor, same dtype as out_dtype + + Returns: + [M, N] tensor of out_dtype + """ + if out_dtype == torch.float16: + dtype_str = "float16" + elif out_dtype == torch.bfloat16: + dtype_str = "bfloat16" + else: + raise ValueError( + f"out_dtype must be torch.float16 or torch.bfloat16, got {out_dtype}" + ) + + mod = _load_module() + + # scales_a from Triton quant is (M,1) float32 — flatten to (M,) + if scales_a.dim() == 2: + scales_a = scales_a.squeeze(-1) + + return mod.int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, dtype_str, bias) diff --git a/mllm-kernel/tests/test_gptq_marlin.py b/mllm-kernel/tests/test_gptq_marlin.py new file mode 100644 index 000000000..7f2bcba72 --- /dev/null +++ b/mllm-kernel/tests/test_gptq_marlin.py @@ -0,0 +1,151 @@ +import pytest +import torch +import torch.nn.functional as F + +from mllm_kernel.cuda.jit import gptq_marlin_gemm, gptq_marlin_repack + + +CUDA_ONLY = pytest.mark.skipif( + not torch.cuda.is_available(), reason="requires CUDA" +) + + +def _compute_scalar_type_id( + exponent: int, + mantissa: int, + signed: bool, + bias: int, + finite_values_only: bool = False, + nan_repr: int = 1, +) -> int: + bit_offset = 0 + result = 0 + for value, width in [ + (exponent, 8), + (mantissa, 8), + (signed, 1), + (bias, 32), + (finite_values_only, 1), + (nan_repr, 8), + ]: + result |= (int(value) & ((1 << width) - 1)) << bit_offset + bit_offset += width + return result + + +SCALAR_TYPE_UINT4B8_ID = _compute_scalar_type_id(0, 4, False, 8) + + +def _pack_checkpoint_weight(q_weight: torch.Tensor, num_bits: int) -> torch.Tensor: + pack_factor = 32 // num_bits + size_n, size_k = q_weight.shape + packed = torch.zeros( + (size_n, size_k // pack_factor), + dtype=torch.int32, + device=q_weight.device, + ) + for i in range(pack_factor): + 
packed.bitwise_or_(q_weight[:, i::pack_factor].int() << (num_bits * i)) + return packed + + +def _get_scale_perms() -> tuple[list[int], list[int]]: + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: list[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]] + ) + return scale_perm, scale_perm_single + + +def _marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + scale_perm, scale_perm_single = _get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + return s.reshape((-1, size_n)).contiguous() + + +def _marlin_make_workspace(device: torch.device) -> torch.Tensor: + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms, dtype=torch.int, device=device, requires_grad=False) + + +@CUDA_ONLY +def test_gptq_marlin_gemm_matches_reference_for_uint4b8() -> None: + torch.manual_seed(2026) + device = torch.device("cuda") + size_m = 13 + size_n = 64 + size_k = 128 + group_size = 32 + num_bits = 4 + + q_weight = torch.randint( + 0, + 1 << num_bits, + (size_n, size_k), + dtype=torch.int32, + device=device, + ) + scales = ( + torch.rand( + (size_n, size_k // group_size), + dtype=torch.float16, + device=device, + ) + + 0.5 + ) + packed = _pack_checkpoint_weight(q_weight, num_bits=num_bits) + empty = torch.empty(0, dtype=torch.int32, device=device) + marlin_q = gptq_marlin_repack( + packed.t().contiguous(), + perm=empty, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + marlin_s = _marlin_permute_scales( + scales.t().contiguous(), + size_k=size_k, + size_n=size_n, + group_size=group_size, + ) + x = torch.randn((size_m, size_k), dtype=torch.float16, device=device) + workspace = _marlin_make_workspace(device) + + out = gptq_marlin_gemm( + a=x, + c=None, + b_q_weight=marlin_q, + b_scales=marlin_s, + global_scale=None, + b_zeros=empty, + g_idx=empty, + perm=empty, + workspace=workspace, + b_q_type_id=SCALAR_TYPE_UINT4B8_ID, + size_m=size_m, + size_n=size_n, + size_k=size_k, + is_k_full=True, + use_atomic_add=False, + use_fp32_reduce=False, + is_zp_float=False, + ) + + ref_weight = (q_weight.to(torch.float16) - 8) * scales.repeat_interleave( + group_size, dim=1 + ) + ref_out = F.linear(x, ref_weight) + rel_mean_err = torch.mean(torch.abs(out - ref_out)) / torch.mean( + torch.abs(ref_out) + ) + + assert rel_mean_err < 0.04 diff --git a/mllm-kernel/tests/test_gptq_marlin_repack.py b/mllm-kernel/tests/test_gptq_marlin_repack.py new file mode 100644 index 000000000..d9b69f3d8 --- /dev/null +++ b/mllm-kernel/tests/test_gptq_marlin_repack.py @@ -0,0 +1,305 @@ +import pytest +import torch + +from mllm_kernel.cuda.jit import gptq_marlin_repack + + +CUDA_ONLY = pytest.mark.skipif( + not torch.cuda.is_available(), reason="requires CUDA" +) + + +def _pack_rows(q_weight: torch.Tensor, num_bits: int) -> torch.Tensor: + pack_factor = 32 // num_bits + size_k, size_n = q_weight.shape + packed = torch.zeros( + (size_k // pack_factor, size_n), + dtype=torch.int32, + device=q_weight.device, + ) + for i in range(pack_factor): + packed.bitwise_or_(q_weight[i::pack_factor].int() << (num_bits * i)) + return packed + + +def _reference_gptq_marlin_repack_cpu( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> 
torch.Tensor: + pack_factor = 32 // num_bits + mask = (1 << num_bits) - 1 + q_weight = torch.empty((size_k, size_n), dtype=torch.int32) + for i in range(pack_factor): + q_weight[i::pack_factor] = ( + (b_q_weight >> (num_bits * i)) & mask + )[0 : q_weight[i::pack_factor].shape[0]] + + if perm.numel() == 0: + perm = torch.arange(size_k, dtype=torch.int32) + + out = torch.empty( + (size_k // 16, size_n * 16 // pack_factor), + dtype=torch.int32, + ) + n_tiles = size_n // 64 + tc_offsets = [0, 1, 8, 9] + pack_idx = [0, 2, 4, 6, 1, 3, 5, 7] + tile_size = 16 * 64 // pack_factor + + for k_tile in range(size_k // 16): + for n_tile in range(n_tiles): + tile = torch.empty((16, 64), dtype=torch.int32) + for local_k in range(16): + src_k = int(perm[k_tile * 16 + local_k].item()) + tile[local_k] = q_weight[src_k, n_tile * 64 : (n_tile + 1) * 64] + + flat = torch.empty(tile_size, dtype=torch.int32) + for warp_id in range(4): + for th_id in range(32): + tc_col = th_id // 4 + tc_row = (th_id % 4) * 2 + cur_n = warp_id * 16 + tc_col + + vals = [int(tile[tc_row + off, cur_n].item()) for off in tc_offsets] + vals.extend( + int(tile[tc_row + off, cur_n + 8].item()) + for off in tc_offsets + ) + + res = 0 + for i, src_idx in enumerate(pack_idx): + res |= vals[src_idx] << (i * num_bits) + if res >= 1 << 31: + res -= 1 << 32 + flat[th_id * 4 + warp_id] = res + + out[k_tile, n_tile * tile_size : (n_tile + 1) * tile_size] = flat + + return out + + +@CUDA_ONLY +@pytest.mark.parametrize( + ("size_k", "size_n", "num_bits"), + [(128, 64, 4), (256, 128, 4)], +) +def test_gptq_marlin_repack_outputs_shape(size_k: int, size_n: int, num_bits: int) -> None: + pack_factor = 32 // num_bits + b_q_weight = torch.empty( + (size_k // pack_factor, size_n), + dtype=torch.int32, + device="cuda", + ) + perm = torch.empty(0, dtype=torch.int32, device="cuda") + + out = gptq_marlin_repack( + b_q_weight, + perm, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + + assert out.dtype == torch.int32 + assert out.shape == (size_k // 16, size_n * 16 // pack_factor) + + +@CUDA_ONLY +@pytest.mark.parametrize( + ("size_k", "size_n", "num_bits"), + [(128, 64, 4), (256, 128, 4)], +) +def test_gptq_marlin_repack_accepts_explicit_perm( + size_k: int, + size_n: int, + num_bits: int, +) -> None: + pack_factor = 32 // num_bits + b_q_weight = torch.empty( + (size_k // pack_factor, size_n), + dtype=torch.int32, + device="cuda", + ) + perm = torch.arange(size_k, dtype=torch.int32, device="cuda") + + out1 = gptq_marlin_repack( + b_q_weight, + perm, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + out2 = gptq_marlin_repack( + b_q_weight, + perm, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + + assert torch.equal(out1, out2) + + +@CUDA_ONLY +@pytest.mark.parametrize( + ("size_k", "size_n", "num_bits"), + [(128, 64, 4), (256, 128, 4)], +) +def test_gptq_marlin_repack_identity_perm_matches_empty_perm( + size_k: int, + size_n: int, + num_bits: int, +) -> None: + pack_factor = 32 // num_bits + b_q_weight = torch.empty( + (size_k // pack_factor, size_n), + dtype=torch.int32, + device="cuda", + ) + empty_perm = torch.empty(0, dtype=torch.int32, device="cuda") + perm = torch.arange(size_k, dtype=torch.int32, device="cuda") + + baseline = gptq_marlin_repack( + b_q_weight, + empty_perm, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + with_perm = gptq_marlin_repack( + b_q_weight, + perm, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + + assert torch.equal(baseline, with_perm) + + +@CUDA_ONLY +def 
test_gptq_marlin_repack_non_identity_perm_matches_reference() -> None: + size_k, size_n, num_bits = 128, 64, 4 + torch.manual_seed(2026) + q_weight = torch.randint( + 0, + 1 << num_bits, + (size_k, size_n), + dtype=torch.int32, + ) + b_q_weight_cpu = _pack_rows(q_weight, num_bits) + perm_cpu = torch.roll(torch.arange(size_k, dtype=torch.int32), 1) + + out = gptq_marlin_repack( + b_q_weight_cpu.to(device="cuda"), + perm_cpu.to(device="cuda"), + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + ref = _reference_gptq_marlin_repack_cpu( + b_q_weight_cpu, + perm_cpu, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + + assert torch.equal(out.cpu(), ref) + + +@CUDA_ONLY +@pytest.mark.parametrize( + ("size_k", "size_n", "num_bits"), + [(128, 64, 4), (256, 128, 4)], +) +def test_gptq_marlin_repack_handles_noncontiguous_perm( + size_k: int, + size_n: int, + num_bits: int, +) -> None: + pack_factor = 32 // num_bits + b_q_weight = torch.empty( + (size_k // pack_factor, size_n), + dtype=torch.int32, + device="cuda", + ) + + buffer = torch.empty( + size_k * 2, + dtype=torch.int32, + device="cuda", + ) + indices = torch.arange(size_k, dtype=torch.int32, device="cuda") + buffer[::2] = indices + buffer[1::2] = indices + perm = buffer.as_strided((size_k,), (2,)) + assert not perm.is_contiguous() + + perm_contig = perm.contiguous() + + out_noncontig = gptq_marlin_repack( + b_q_weight, + perm, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + out_contig = gptq_marlin_repack( + b_q_weight, + perm_contig, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) + + assert torch.equal(out_noncontig, out_contig) + + +@CUDA_ONLY +@pytest.mark.parametrize( + ("size_k", "size_n", "num_bits"), + [(128, 64, 4), (256, 128, 4)], +) +@pytest.mark.parametrize( + "perm_factory", + [ + lambda size_k: torch.arange(size_k, dtype=torch.int32, device="cpu"), + lambda size_k: torch.arange(size_k, dtype=torch.int64, device="cuda"), + lambda size_k: torch.arange(size_k - 16, dtype=torch.int32, device="cuda"), + lambda size_k: torch.full((size_k,), size_k, dtype=torch.int32, device="cuda"), + lambda size_k: torch.full((size_k,), -1, dtype=torch.int32, device="cuda"), + ], + ids=[ + "device-mismatch", + "dtype-mismatch", + "length-mismatch", + "out-of-range", + "negative-index", + ], +) +def test_gptq_marlin_repack_rejects_invalid_perm( + size_k: int, + size_n: int, + num_bits: int, + perm_factory, +) -> None: + pack_factor = 32 // num_bits + b_q_weight = torch.empty( + (size_k // pack_factor, size_n), + dtype=torch.int32, + device="cuda", + ) + perm = perm_factory(size_k) + + with pytest.raises(ValueError): + gptq_marlin_repack( + b_q_weight, + perm, + size_k=size_k, + size_n=size_n, + num_bits=num_bits, + ) diff --git a/mllm-kernel/tests/test_int8_scaled_mm_cutlass.py b/mllm-kernel/tests/test_int8_scaled_mm_cutlass.py new file mode 100644 index 000000000..49e43b5e1 --- /dev/null +++ b/mllm-kernel/tests/test_int8_scaled_mm_cutlass.py @@ -0,0 +1,150 @@ +"""Correctness tests for CUTLASS int8_scaled_mm kernel.""" +from __future__ import annotations + +from pathlib import Path + +import pytest +import torch + + +def _cutlass_source() -> str: + return ( + Path(__file__).resolve().parents[1] + / "mllm_kernel" + / "cuda" + / "csrc" + / "gemm" + / "int8" + / "int8_scaled_mm_cutlass.cu" + ).read_text() + + +def _reference_int8_scaled_mm( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + scales_a: torch.Tensor, + scales_b: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None, +) -> torch.Tensor: 
+ """fp32 reference implementation.""" + out = torch.matmul(mat_a.to(torch.float32), mat_b.to(torch.float32)) + out = out * scales_a.view(-1, 1).float() * scales_b.view(1, -1).float() + if bias is not None: + out = out + bias.float() + return out.to(out_dtype) + + +@pytest.fixture(scope="module") +def cutlass_module(): + """Load CUTLASS module once for all tests.""" + pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import int8_scaled_mm + return int8_scaled_mm + + +def test_cutlass_wrapper_rejects_unsupported_out_dtype(monkeypatch): + from mllm_kernel.cuda.jit import int8_scaled_mm_cutlass as cutlass_wrapper + + class FakeModule: + def int8_scaled_mm(self, *args, **kwargs): + return torch.empty((1, 8), dtype=torch.bfloat16) + + monkeypatch.setattr(cutlass_wrapper, "_load_module", lambda: FakeModule()) + + mat_a = torch.empty((1, 16), dtype=torch.int8) + mat_b = torch.empty((16, 8), dtype=torch.int8) + scales_a = torch.empty((1,), dtype=torch.float32) + scales_b = torch.empty((8,), dtype=torch.float32) + + with pytest.raises(ValueError, match="out_dtype"): + cutlass_wrapper.int8_scaled_mm( + mat_a, mat_b, scales_a, scales_b, torch.float32, + ) + + +def test_cutlass_jit_uses_current_gpu_arch_for_compile(monkeypatch): + import torch.utils.cpp_extension as cpp_extension + + from mllm_kernel.cuda.jit import int8_scaled_mm_cutlass as cutlass_wrapper + + calls = {} + + class FakeLoadedModule: + pass + + def fake_load(**kwargs): + calls.update(kwargs) + return FakeLoadedModule() + + monkeypatch.setattr(cutlass_wrapper, "_module", None) + monkeypatch.setattr(cutlass_wrapper, "_module_arch", None, raising=False) + monkeypatch.setattr(cutlass_wrapper, "_CUTLASS_INC", None) + monkeypatch.setattr( + cutlass_wrapper, + "_find_cutlass_include", + lambda: "/tmp/cutlass/include", + ) + monkeypatch.setattr( + cutlass_wrapper.torch.cuda, + "get_device_capability", + lambda: (8, 9), + ) + monkeypatch.setattr(cpp_extension, "load", fake_load) + + cutlass_wrapper._load_module() + + assert "-arch=sm_89" in calls["extra_cuda_cflags"] + assert calls["name"].endswith("_sm_89") + assert calls["build_directory"].endswith("sm_89") + + +def test_cutlass_dispatch_keeps_sglang_sm80_sm89_split(): + source = _cutlass_source() + + assert "if (sm_version == 86 || sm_version == 89)" in source + assert "sm89_dispatch_shape= 3 + + +@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("with_bias", [False, True]) +@pytest.mark.parametrize( + "M,N,K", + [ + (1, 64, 32), + (1, 2048, 2048), + (8, 128, 64), + (16, 6144, 2048), + (32, 2048, 2048), + (93, 6144, 2048), + (128, 2048, 6144), + ], +) +def test_cutlass_matches_reference( + cutlass_module, M, N, K, out_dtype, with_bias, +): + torch.manual_seed(42) + mat_a = torch.randint(-127, 128, (M, K), dtype=torch.int8, device="cuda") + mat_b = torch.randint(-127, 128, (K, N), dtype=torch.int8, device="cuda") + # Make col-major B + mat_b_col = mat_b.t().contiguous().t() + + scales_a = (torch.rand(M, dtype=torch.float32, device="cuda") + 0.01) * 0.01 + scales_b = (torch.rand(N, dtype=torch.float32, device="cuda") + 0.01) * 0.01 + bias = torch.randn(N, dtype=out_dtype, device="cuda") * 0.01 if with_bias else None + + out = cutlass_module(mat_a, mat_b_col, scales_a, scales_b, out_dtype, bias) + ref = _reference_int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias) + + torch.testing.assert_close(out, ref, atol=0.1, rtol=0.05) diff --git 
a/pymllm/README-ZH.md b/pymllm/README-ZH.md new file mode 100644 index 000000000..a32c6580c --- /dev/null +++ b/pymllm/README-ZH.md @@ -0,0 +1,254 @@ +# pymllm + +![pymllm-arch](../assets/pymllm-arch.png) + +`pymllm` 是 `mllm` 的 Python 推理服务入口。本目录当前重点覆盖 +Jetson Orin 上的 Qwen3 / Qwen3-VL 推理、OpenAI-compatible server、 +`compressed-tensors` 量化加载,以及 W8A8 INT8 kernel 路径。 + +本文档按 2026-04-27 的开发状态整理,适用于当前集成分支: + +```text +feature/jetson-qwen3-family-bf16-w4a16-w8a8 +``` + +## 当前状态 + +已验证路径: + +- `Qwen3-VL-2B-Instruct`:BF16 原生模型服务可用。 +- `Qwen3-VL-2B-Instruct-AWQ-4bit`:`compressed-tensors` + W4A16 / AWQ Marlin 路径可用。 +- `Qwen3-VL-2B-Instruct-quantized.w8a8`:`compressed-tensors` + W8A8 `int-quantized` 路径端到端可用。 + +已实现并纳入单元测试的模型/组件: + +- `Qwen3VLForConditionalGeneration`:图文模型服务主路径。 +- `Qwen3ForCausalLM`:文本模型骨架、权重加载与 timing 字段测试。 +- `compressed-tensors`: + - `pack-quantized` 4-bit 权重路径,使用 GPTQ Marlin。 + - `int-quantized` W8A8 路径,使用 Triton 激活量化 + CUTLASS + `int8_scaled_mm`。 + +W8A8 当前前向链路: + +```text +x(fp16/bf16) + -> per_token_quant_int8 [Triton, dynamic per-token activation quant] + -> int8_scaled_mm [CUTLASS, INT8 Tensor Core, fused scales] + -> output(fp16/bf16) +``` + +## 已验证环境 + +以下命令基于 Jetson Orin 环境整理: + +- JetPack / L4T:`R36.4.4`(来自 `/etc/nv_tegra_release`) +- Python:`3.10.12` +- PyTorch:`2.4.0` +- torchvision:`0.19.0a0+48b1edf` +- transformers:`5.3.0` +- safetensors:`0.7.0` +- flashinfer:`0.6.7` +- Triton Language:官方 PyPI `triton==3.6.0` manylinux aarch64 wheel +- CUDA:`12.6` +- GPU:Jetson Orin NX,SM87 + +这里的 Triton 指 GPU kernel DSL,不是 Triton Inference Server。Jetson-AI-Lab +源也提供 `3.4.0`、`3.5.1`、`3.6.0`,但实测中可能需要额外设置 +`TRITON_PTXAS_PATH` 和 `CPATH`。当前建议优先使用官方 PyPI 的 +`triton==3.6.0`,并用最小 CUDA kernel 或 `per_token_quant_int8` 做 smoke test。 + +W8A8 CUTLASS JIT 需要能找到 CUTLASS 头文件。当前查找顺序为: + +1. `CUTLASS_HOME/include` +2. `flashinfer` 内置的 `data/cutlass/include` +3. `/usr/local/include`、`/usr/include`、`/usr/local/cuda/include` + +首次调用 CUTLASS kernel 会触发 JIT 编译,耗时约 100 秒;后续会复用: + +```text +~/.cache/mllm_kernel/cutlass_int8_scaled_mm/ +``` + +## 安装开发环境 + +在仓库根目录执行: + +```bash +cd +SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e . 
+python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation +``` + +最小导入检查: + +```bash +python3 - <<'PY' +import pymllm +import mllm_kernel + +print("pymllm import ok") +print("mllm_kernel import ok") +PY +``` + +## 启动服务 + +### 量化模型(W4A16 / W8A8) + +```bash +cd + +python3 -m pymllm.server.launch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype float16 \ + --quantization.method compressed-tensors \ + --server.host 0.0.0.0 \ + --server.port 30000 \ + --server.attention_backend auto \ + --server.gdn_decode_backend pytorch \ + --server.mem_fraction_static 0.05 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 256 \ + --server.max_prefill_tokens 128 \ + --server.chunked_prefill_size 128 \ + --server.disable_radix_cache \ + --server.disable_cuda_graph \ + --server.log_level debug +``` + +说明: + +- `--quantization.method compressed-tensors` 会按模型 `config.json` + 自动识别 W4A16 或 W8A8 签名。 +- W8A8 路径要求 GPU capability 不低于 SM80。 +- `--server.disable_radix_cache` 会使用 `ChunkCache`,当前已修复该模式下的 + KV slot 泄漏问题。 +- 若 `30000` 已被占用,可改成其他空闲端口。 + +### BF16 原生模型 + +```bash +cd + +python3 -m pymllm.server.launch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype float16 \ + --server.host 0.0.0.0 \ + --server.port 30000 \ + --server.attention_backend auto \ + --server.gdn_decode_backend pytorch \ + --server.mem_fraction_static 0.05 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 256 \ + --server.max_prefill_tokens 128 \ + --server.chunked_prefill_size 128 \ + --server.disable_radix_cache \ + --server.disable_cuda_graph \ + --server.log_level debug +``` + +## 调用示例 + +### 健康检查 + +```bash +curl -s --noproxy '*' http://127.0.0.1:30000/v1/models ; echo +``` + +期望返回中包含: + +```text +"owned_by":"pymllm" +``` + +### 文本请求 + +```bash +curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "None", + "messages": [{"role": "user", "content": "你好,只回复:ok"}], + "max_tokens": 8, + "temperature": 0.0, + "stream": false + }' ; echo +``` + +### 图文请求 + +图片路径请使用容器内可访问的绝对路径,不要使用 `file://...` 前缀。 + +```bash +python3 - <<'PY' +import json + +payload = { + "model": "None", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "请详细描述这张图片。"}, + {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}}, + ], + } + ], + "max_tokens": 128, + "temperature": 0.0, + "stream": False, +} + +with open("/tmp/mm_req_path.json", "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False) + +print("saved /tmp/mm_req_path.json") +PY + +curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + --data @/tmp/mm_req_path.json ; echo +``` + +## 开发与测试 + +常用单元测试: + +```bash +pytest pymllm/tests/test_compressed_tensors_config.py -q +pytest pymllm/tests/test_compressed_tensors_runtime.py -q +pytest pymllm/tests/test_qwen3_model_registry.py -q +pytest pymllm/tests/test_qwen3_weight_loading.py -q +pytest pymllm/tests/test_qwen3_forward_timing.py -q +pytest mllm-kernel/tests/test_int8_scaled_mm_cutlass.py -q +``` + +常用 microbench: + +```bash +python3 pymllm/tests/bench_w8a8_activation_quant.py +python3 mllm-kernel/benchmarks/bench_int8_scaled_mm.py +python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py +``` + +如果需要重新测 CUTLASS 首次编译,可先清理 JIT 缓存: + +```bash +rm -rf ~/.cache/mllm_kernel/cutlass_int8_scaled_mm/ +``` + +## 已知限制 + +- W8A8 CUTLASS 
当前通过 JIT 编译,首次启动存在约 100 秒编译开销。 +- W8A8 激活量化使用 Triton kernel;decode 下固定量化开销仍是后续优化点。 +- Qwen3-VL 的 ViT、`lm_head`、embedding 和 LayerNorm 不在当前 W8A8 量化范围内。 +- 其他 GPU 需要重新验证 tile dispatch、JIT 编译和性能。 +- 为对齐 SGLang/OpenAI 兼容响应,OpenAI API 默认不返回 debug timing。 + 仅在本地诊断时使用 `--server.enable_debug_timing`;严格模型级计时应使用专用 benchmark。 diff --git a/pymllm/README.md b/pymllm/README.md index bee5ac41c..439f74bc7 100644 --- a/pymllm/README.md +++ b/pymllm/README.md @@ -1,3 +1,265 @@ # pymllm ![pymllm-arch](../assets/pymllm-arch.png) + +`pymllm` is the Python inference and serving entry point for `mllm`. This +directory currently focuses on Qwen3 / Qwen3-VL serving on Jetson Orin, +OpenAI-compatible APIs, `compressed-tensors` quantized loading, and the W8A8 +INT8 kernel path. + +This README reflects the development state as of 2026-04-27 for the integration +branch: + +```text +feature/jetson-qwen3-family-bf16-w4a16-w8a8 +``` + +## Current status + +Validated paths: + +- `Qwen3-VL-2B-Instruct`: BF16 base-model serving. +- `Qwen3-VL-2B-Instruct-AWQ-4bit`: `compressed-tensors` W4A16 / AWQ Marlin + serving. +- `Qwen3-VL-2B-Instruct-quantized.w8a8`: `compressed-tensors` W8A8 + `int-quantized` end-to-end serving. + +Implemented and unit-tested models/components: + +- `Qwen3VLForConditionalGeneration`: the main multimodal serving path. +- `Qwen3ForCausalLM`: text-only model skeleton, weight loading, and timing + tests. +- `compressed-tensors`: + - `pack-quantized` 4-bit weight path via GPTQ Marlin. + - `int-quantized` W8A8 path via Triton activation quantization and CUTLASS + `int8_scaled_mm`. + +The current W8A8 forward path is: + +```text +x(fp16/bf16) + -> per_token_quant_int8 [Triton, dynamic per-token activation quant] + -> int8_scaled_mm [CUTLASS, INT8 Tensor Core, fused scales] + -> output(fp16/bf16) +``` + +## Validated environment + +The commands below were validated on Jetson Orin with: + +- JetPack / L4T: `R36.4.4` (`/etc/nv_tegra_release`) +- Python: `3.10.12` +- PyTorch: `2.4.0` +- torchvision: `0.19.0a0+48b1edf` +- transformers: `5.3.0` +- safetensors: `0.7.0` +- flashinfer: `0.6.7` +- Triton Language: official PyPI `triton==3.6.0` manylinux aarch64 wheel +- CUDA: `12.6` +- GPU: Jetson Orin NX, SM87 + +Triton here means the GPU kernel DSL, not Triton Inference Server. The +Jetson-AI-Lab index also provides `3.4.0`, `3.5.1`, and `3.6.0`, but the tested +environment may require extra `TRITON_PTXAS_PATH` and `CPATH` settings with +those wheels. For this project, prefer the official PyPI `triton==3.6.0` wheel +and verify it with a minimal CUDA kernel or `per_token_quant_int8` smoke test. + +The W8A8 CUTLASS JIT path requires CUTLASS headers. The lookup order is: + +1. `CUTLASS_HOME/include` +2. `flashinfer` bundled `data/cutlass/include` +3. `/usr/local/include`, `/usr/include`, `/usr/local/cuda/include` + +The first CUTLASS kernel call triggers JIT compilation and may take about +100 seconds. Later runs reuse: + +```text +~/.cache/mllm_kernel/cutlass_int8_scaled_mm/ +``` + +## Install the development environment + +Run from the repository root: + +```bash +cd +SKBUILD_WHEEL_CMAKE=false python3 -m pip install -e . 
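+# Editable install of the JIT kernel package; --no-deps / --no-build-isolation
+# reuse the wheels already present in this environment instead of re-resolving them.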
+python3 -m pip install -e /mllm-kernel --no-deps --no-build-isolation +``` + +Run a minimal import check: + +```bash +python3 - <<'PY' +import pymllm +import mllm_kernel + +print("pymllm import ok") +print("mllm_kernel import ok") +PY +``` + +## Launch the server + +### Quantized models (W4A16 / W8A8) + +```bash +cd + +python3 -m pymllm.server.launch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype float16 \ + --quantization.method compressed-tensors \ + --server.host 0.0.0.0 \ + --server.port 30000 \ + --server.attention_backend auto \ + --server.gdn_decode_backend pytorch \ + --server.mem_fraction_static 0.05 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 256 \ + --server.max_prefill_tokens 128 \ + --server.chunked_prefill_size 128 \ + --server.disable_radix_cache \ + --server.disable_cuda_graph \ + --server.log_level debug +``` + +Notes: + +- `--quantization.method compressed-tensors` reads the model `config.json` and + selects the W4A16 or W8A8 signature automatically. +- W8A8 requires SM80 or newer GPUs. +- `--server.disable_radix_cache` uses `ChunkCache`; the KV slot leak in this + mode has been fixed. +- If port `30000` is already in use, switch to another free port. + +### BF16 base models + +```bash +cd + +python3 -m pymllm.server.launch \ + --server.model_path \ + --server.tokenizer_path \ + --server.load_format safetensors \ + --server.dtype float16 \ + --server.host 0.0.0.0 \ + --server.port 30000 \ + --server.attention_backend auto \ + --server.gdn_decode_backend pytorch \ + --server.mem_fraction_static 0.05 \ + --server.max_running_requests 1 \ + --server.max_total_tokens 256 \ + --server.max_prefill_tokens 128 \ + --server.chunked_prefill_size 128 \ + --server.disable_radix_cache \ + --server.disable_cuda_graph \ + --server.log_level debug +``` + +## Request examples + +### Health check + +```bash +curl -s --noproxy '*' http://127.0.0.1:30000/v1/models ; echo +``` + +Expected response contains: + +```text +"owned_by":"pymllm" +``` + +### Text request + +```bash +curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "None", + "messages": [{"role": "user", "content": "Reply with: ok"}], + "max_tokens": 8, + "temperature": 0.0, + "stream": false + }' ; echo +``` + +### Image request + +Use a container-visible absolute image path. Do not use the `file://...` +prefix. 
+ +```bash +python3 - <<'PY' +import json + +payload = { + "model": "None", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Please describe this image in detail."}, + {"type": "image_url", "image_url": {"url": "/workspace/xcd_mllm/test.png"}}, + ], + } + ], + "max_tokens": 128, + "temperature": 0.0, + "stream": False, +} + +with open("/tmp/mm_req_path.json", "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False) + +print("saved /tmp/mm_req_path.json") +PY + +curl -s --noproxy '*' http://127.0.0.1:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + --data @/tmp/mm_req_path.json ; echo +``` + +## Development and tests + +Common unit tests: + +```bash +pytest pymllm/tests/test_compressed_tensors_config.py -q +pytest pymllm/tests/test_compressed_tensors_runtime.py -q +pytest pymllm/tests/test_qwen3_model_registry.py -q +pytest pymllm/tests/test_qwen3_weight_loading.py -q +pytest pymllm/tests/test_qwen3_forward_timing.py -q +pytest mllm-kernel/tests/test_int8_scaled_mm_cutlass.py -q +``` + +Common microbenchmarks: + +```bash +python3 pymllm/tests/bench_w8a8_activation_quant.py +python3 mllm-kernel/benchmarks/bench_int8_scaled_mm.py +python3 mllm-kernel/benchmarks/bench_w4a16_vs_w8a8.py +``` + +To measure first-use CUTLASS compilation again, clear the JIT cache: + +```bash +rm -rf ~/.cache/mllm_kernel/cutlass_int8_scaled_mm/ +``` + +## Known limitations + +- The W8A8 CUTLASS path is JIT-compiled, so first startup includes about + 100 seconds of compilation overhead. +- W8A8 activation quantization uses a Triton kernel; its fixed decode-time + cost remains a future optimization target. +- Qwen3-VL ViT, `lm_head`, embeddings, and LayerNorm are outside the current + W8A8 quantized scope. +- Other GPUs need separate validation for tile dispatch, JIT compilation, and + performance. +- OpenAI-compatible responses hide debug timing by default for SGLang/OpenAI + compatibility. Use `--server.enable_debug_timing` only for local diagnostics; + strict model-level timing should use dedicated benchmarks. diff --git a/pymllm/bench_one_batch.py b/pymllm/bench_one_batch.py new file mode 100644 index 000000000..a62be2bb2 --- /dev/null +++ b/pymllm/bench_one_batch.py @@ -0,0 +1,691 @@ +"""SGLang-style one-batch benchmark for pymllm. + +This module intentionally bypasses the HTTP server, tokenizer workers, +scheduler, and detokenizer. It drives :class:`pymllm.executor.ModelRunner` +directly to measure one static prefill followed by token-by-token decode. 
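+
+Example invocation (the ``--server.*`` options come from ``make_args`` and depend
+on the local checkpoint; the sweep flags below are defined in this module):
+
+    python3 -m pymllm.bench_one_batch \
+        --server.model_path <model_dir> \
+        --batch-size 1 --input-len 256 --output-len 128 \
+        --result-filename /tmp/pymllm_bench_one_batch.jsonl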
+""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import re +import statistics +import time +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Iterator, Optional, Sequence + +import torch + +from pymllm.configs.global_config import GlobalConfig, make_args, read_args +from pymllm.executor.model_runner import ModelRunner + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BenchSetting: + batch_size: int + input_len: int + output_len: int + + +@dataclass +class BenchArgs: + run_name: str = "default" + batch_size: list[int] = field(default_factory=lambda: [1]) + input_len: list[int] = field(default_factory=lambda: [256, 512, 1024]) + output_len: list[int] = field(default_factory=lambda: [128]) + result_filename: Path = Path("/tmp/pymllm_bench_one_batch.jsonl") + log_decode_step: int = 0 + seed: int = 42 + profile: bool = False + profile_record_shapes: bool = False + profile_activities: list[str] = field(default_factory=lambda: ["CPU", "GPU"]) + profile_stage: str = "all" + profile_filename_prefix: str = "pymllm_profile" + profile_start_step: Optional[int] = None + profile_steps: int = 1 + skip_warmup: bool = False + + +@dataclass +class DecodeState: + req_pool_indices: torch.Tensor + seq_lens: torch.Tensor + mrope_position_deltas: Optional[torch.Tensor] = None + + +def _positive_int(value: str) -> int: + parsed = int(value) + if parsed <= 0: + raise argparse.ArgumentTypeError(f"Expected a positive integer, got {value!r}") + return parsed + + +def _non_negative_int(value: str) -> int: + parsed = int(value) + if parsed < 0: + raise argparse.ArgumentTypeError( + f"Expected a non-negative integer, got {value!r}" + ) + return parsed + + +def add_bench_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + "bench_one_batch", + "Options for the low-level one-batch benchmark.", + ) + group.add_argument("--run-name", default=BenchArgs.run_name) + group.add_argument( + "--batch-size", + nargs="+", + type=_positive_int, + default=[1], + help="Batch sizes to sweep.", + ) + group.add_argument( + "--input-len", + nargs="+", + type=_positive_int, + default=[256, 512, 1024], + help="Prefill/input lengths to sweep.", + ) + group.add_argument( + "--output-len", + nargs="+", + type=_positive_int, + default=[128], + help="Output lengths to sweep. Matches SGLang's total output token semantics.", + ) + group.add_argument( + "--result-filename", + type=Path, + default=BenchArgs.result_filename, + help="JSONL result file. Rows are appended.", + ) + group.add_argument( + "--log-decode-step", + type=_non_negative_int, + default=0, + help="Log every N decode steps. 0 disables per-step logging.", + ) + group.add_argument("--seed", type=int, default=42) + group.add_argument("--profile", action="store_true") + group.add_argument("--profile-record-shapes", action="store_true") + group.add_argument( + "--profile-activities", + nargs="+", + choices=["CPU", "GPU"], + default=["CPU", "GPU"], + ) + group.add_argument( + "--profile-stage", + choices=["all", "prefill", "decode"], + default="all", + ) + group.add_argument( + "--profile-filename-prefix", + default=BenchArgs.profile_filename_prefix, + ) + group.add_argument( + "--profile-start-step", + type=_non_negative_int, + default=None, + help="Decode step index where profiling starts. 
Defaults to the middle step.", + ) + group.add_argument( + "--profile-steps", + type=_positive_int, + default=1, + help="Number of decode steps to profile.", + ) + group.add_argument( + "--skip-warmup", + action="store_true", + help="Skip the initial non-recorded warmup run.", + ) + return parser + + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="python3 -m pymllm.bench_one_batch", + description="Run a SGLang-style direct ModelRunner one-batch benchmark.", + ) + make_args(parser) + add_bench_args(parser) + return parser + + +def _bench_args_from_namespace(namespace: argparse.Namespace) -> BenchArgs: + return BenchArgs( + run_name=namespace.run_name, + batch_size=list(namespace.batch_size), + input_len=list(namespace.input_len), + output_len=list(namespace.output_len), + result_filename=Path(namespace.result_filename), + log_decode_step=namespace.log_decode_step, + seed=namespace.seed, + profile=namespace.profile, + profile_record_shapes=namespace.profile_record_shapes, + profile_activities=list(namespace.profile_activities), + profile_stage=namespace.profile_stage, + profile_filename_prefix=namespace.profile_filename_prefix, + profile_start_step=namespace.profile_start_step, + profile_steps=namespace.profile_steps, + skip_warmup=namespace.skip_warmup, + ) + + +def parse_args( + argv: Optional[Sequence[str]] = None, +) -> tuple[GlobalConfig, BenchArgs]: + parser = make_parser() + cfg = read_args(argv=argv, parser=parser) + namespace = parser.parse_args(argv) + return cfg, _bench_args_from_namespace(namespace) + + +def generate_settings(args: BenchArgs) -> list[BenchSetting]: + return [ + BenchSetting(batch_size=batch_size, input_len=input_len, output_len=output_len) + for batch_size in args.batch_size + for input_len in args.input_len + for output_len in args.output_len + ] + + +def make_synthetic_input_ids( + *, + batch_size: int, + input_len: int, + vocab_size: int, + seed: int, + device: str | torch.device, +) -> torch.Tensor: + upper = max(1, min(int(vocab_size or 10000), 10000)) + generator = torch.Generator(device="cpu") + generator.manual_seed(seed) + input_ids = torch.randint( + low=0, + high=upper, + size=(batch_size, input_len), + generator=generator, + dtype=torch.int32, + device="cpu", + ) + return input_ids.to(device=device) + + +def summarize_latencies( + *, + setting: BenchSetting, + prefill_latency: float, + decode_latencies: Sequence[float], + run_name: str, + device: str, + dtype: str, + cuda_graph: bool, + extra: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + median_decode_latency = ( + float(statistics.median(decode_latencies)) if decode_latencies else 0.0 + ) + total_latency = float(prefill_latency + sum(decode_latencies)) + result: dict[str, Any] = { + "run_name": run_name, + "batch_size": setting.batch_size, + "input_len": setting.input_len, + "output_len": setting.output_len, + "prefill_latency": float(prefill_latency), + "prefill_throughput": _safe_div( + setting.batch_size * setting.input_len, + prefill_latency, + ), + "median_decode_latency": median_decode_latency, + "median_decode_throughput": _safe_div( + setting.batch_size, + median_decode_latency, + ), + "total_latency": total_latency, + "overall_throughput": _safe_div( + setting.batch_size * (setting.input_len + setting.output_len), + total_latency, + ), + "device": device, + "dtype": dtype, + "cuda_graph": cuda_graph, + } + if extra: + result.update(extra) + return result + + +def make_profile_trace_path( + *, + output_dir: Path, + prefix: str, + 
run_name: str, + setting: BenchSetting, + stage: str, + step: Optional[int] = None, +) -> Path: + safe_run_name = _sanitize_filename_part(run_name) + safe_prefix = _sanitize_filename_part(prefix) + step_part = f"_step{step}" if step is not None else "" + filename = ( + f"{safe_prefix}_{safe_run_name}_bs{setting.batch_size}" + f"_in{setting.input_len}_out{setting.output_len}_{stage}" + f"{step_part}.trace.json" + ) + return output_dir / filename + + +def _sanitize_filename_part(value: str) -> str: + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_") + return sanitized or "default" + + +def _safe_div(numerator: float, denominator: float) -> float: + if denominator <= 0: + return 0.0 + return float(numerator / denominator) + + +def _sync_device(device: str | torch.device) -> None: + torch_device = torch.device(device) + if torch_device.type == "cuda": + torch.cuda.synchronize(torch_device) + + +def _configure_logging(level_name: str) -> None: + level = getattr(logging, level_name.upper(), logging.INFO) + root_logger = logging.getLogger() + if not root_logger.handlers: + logging.basicConfig( + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + else: + root_logger.setLevel(level) + logging.getLogger("pymllm").setLevel(level) + + +def _load_hf_config(cfg: GlobalConfig) -> None: + if cfg.server.model_path is None: + raise ValueError("--server.model_path is required") + + from transformers import AutoConfig + + cfg.model.hf_config = AutoConfig.from_pretrained( + str(cfg.server.model_path), + trust_remote_code=cfg.server.trust_remote_code, + ) + logger.info("Loaded model config: %s", cfg.model.hf_config.__class__.__name__) + + +def _append_jsonl(path: Path, row: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as fp: + fp.write(json.dumps(row, sort_keys=True) + "\n") + + +def _profile_stage_enabled(args: BenchArgs, stage: str) -> bool: + return args.profile and args.profile_stage in ("all", stage) + + +def _profiler_activities(args: BenchArgs) -> list[Any]: + from torch.profiler import ProfilerActivity + + activities = [] + if "CPU" in args.profile_activities: + activities.append(ProfilerActivity.CPU) + if "GPU" in args.profile_activities: + if torch.cuda.is_available(): + activities.append(ProfilerActivity.CUDA) + else: + logger.warning("GPU profiling requested but CUDA is not available.") + return activities + + +@contextmanager +def _maybe_profile( + *, + args: BenchArgs, + setting: BenchSetting, + stage: str, + step: Optional[int] = None, +) -> Iterator[None]: + if not _profile_stage_enabled(args, stage): + with nullcontext(): + yield + return + + activities = _profiler_activities(args) + if not activities: + with nullcontext(): + yield + return + + from torch.profiler import profile + + output_dir = Path(os.environ.get("PYMLLM_TORCH_PROFILER_DIR", "/tmp")) + output_dir.mkdir(parents=True, exist_ok=True) + trace_path = make_profile_trace_path( + output_dir=output_dir, + prefix=args.profile_filename_prefix, + run_name=args.run_name, + setting=setting, + stage=stage, + step=step, + ) + with profile( + activities=activities, + record_shapes=args.profile_record_shapes, + ) as profiler: + yield + profiler.step() + profiler.export_chrome_trace(str(trace_path)) + logger.info("Wrote torch profiler trace: %s", trace_path) + + +class PymllmBenchRunner: + def __init__(self, runner: ModelRunner): + self.runner = runner + self.device = runner.device + + @classmethod + def create(cls, cfg: 
GlobalConfig) -> "PymllmBenchRunner": + runner = ModelRunner( + server_config=cfg.server, + model_config=cfg.model, + gpu_id=cfg.server.base_gpu_id, + ) + runner.initialize() + return cls(runner) + + def clear(self) -> None: + if self.runner.req_to_token_pool is None: + raise RuntimeError("ModelRunner req_to_token_pool is not initialized") + if self.runner.token_to_kv_pool_allocator is None: + raise RuntimeError( + "ModelRunner token_to_kv_pool_allocator is not initialized" + ) + self.runner.req_to_token_pool.clear() + self.runner.token_to_kv_pool_allocator.clear() + + def extend(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, DecodeState]: + if input_ids.dim() != 2: + raise ValueError("input_ids must have shape [batch_size, input_len]") + + self._require_initialized() + batch_size, input_len = input_ids.shape + req_slots = self.runner.req_to_token_pool.alloc(batch_size) + if req_slots is None: + raise RuntimeError(f"Failed to allocate {batch_size} request slots") + + total_tokens = batch_size * input_len + out_cache_loc = self.runner.token_to_kv_pool_allocator.alloc(total_tokens) + if out_cache_loc is None: + for slot in req_slots: + self.runner.req_to_token_pool.free(slot) + raise RuntimeError(f"Failed to allocate {total_tokens} KV slots") + + offset = 0 + for slot in req_slots: + self.runner.req_to_token_pool.write( + (slot, slice(0, input_len)), + out_cache_loc[offset : offset + input_len], + ) + offset += input_len + + req_pool_indices = torch.tensor( + req_slots, dtype=torch.int64, device=self.device + ) + if self.runner.gdn_pool is not None: + self.runner.gdn_pool.reset_states(req_pool_indices) + + seq_lens = torch.full( + (batch_size,), + input_len, + dtype=torch.int32, + device=self.device, + ) + extend_seq_lens = torch.full_like(seq_lens, input_len) + extend_prefix_lens = torch.zeros_like(seq_lens) + + forward_batch = self.runner.prepare_forward_batch_extend( + input_ids=input_ids.reshape(-1).to(device=self.device, dtype=torch.int32), + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + extend_seq_lens=extend_seq_lens, + extend_prefix_lens=extend_prefix_lens, + out_cache_loc=out_cache_loc.to(torch.int64), + ) + logits_output = self.runner.forward(forward_batch) + next_token_ids = self._sample_greedy(logits_output, forward_batch) + state = DecodeState( + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + mrope_position_deltas=getattr( + forward_batch, "mrope_position_deltas", None + ), + ) + return next_token_ids, state + + def decode( + self, + input_ids: torch.Tensor, + state: DecodeState, + ) -> tuple[torch.Tensor, DecodeState]: + self._require_initialized() + batch_size = int(state.req_pool_indices.shape[0]) + if input_ids.shape != (batch_size,): + raise ValueError( + f"decode input_ids must have shape ({batch_size},), got {tuple(input_ids.shape)}" + ) + + out_cache_loc = self.runner.token_to_kv_pool_allocator.alloc(batch_size) + if out_cache_loc is None: + raise RuntimeError(f"Failed to allocate {batch_size} decode KV slots") + + seq_lens = state.seq_lens + 1 + for i in range(batch_size): + slot = int(state.req_pool_indices[i].item()) + write_pos = int(seq_lens[i].item()) - 1 + self.runner.req_to_token_pool.write( + (slot, slice(write_pos, write_pos + 1)), + out_cache_loc[i : i + 1], + ) + + forward_batch = self.runner.prepare_forward_batch_decode( + input_ids=input_ids.to(device=self.device, dtype=torch.int32), + req_pool_indices=state.req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc.to(torch.int64), + 
mrope_position_deltas=state.mrope_position_deltas, + ) + logits_output = self.runner.forward(forward_batch) + next_token_ids = self._sample_greedy(logits_output, forward_batch) + return next_token_ids, DecodeState( + req_pool_indices=state.req_pool_indices, + seq_lens=seq_lens, + mrope_position_deltas=state.mrope_position_deltas, + ) + + def shutdown(self) -> None: + self.runner.shutdown() + + def _sample_greedy(self, logits_output: Any, forward_batch: Any) -> torch.Tensor: + temperatures = torch.zeros( + (forward_batch.batch_size,), + dtype=torch.float32, + device=self.device, + ) + return self.runner.sample( + logits_output, + forward_batch, + temperatures=temperatures, + ).to(torch.int32) + + def _require_initialized(self) -> None: + if self.runner.req_to_token_pool is None: + raise RuntimeError("ModelRunner req_to_token_pool is not initialized") + if self.runner.token_to_kv_pool_allocator is None: + raise RuntimeError( + "ModelRunner token_to_kv_pool_allocator is not initialized" + ) + + +def _timed_call( + device: str | torch.device, + fn: Any, +) -> tuple[float, Any]: + _sync_device(device) + tic = time.perf_counter() + result = fn() + _sync_device(device) + return time.perf_counter() - tic, result + + +def run_single_setting( + *, + bench_runner: PymllmBenchRunner, + args: BenchArgs, + setting: BenchSetting, + seed: int, + record_result: bool, +) -> Optional[dict[str, Any]]: + bench_runner.clear() + vocab_size = getattr(bench_runner.runner, "vocab_size", 10000) + input_ids = make_synthetic_input_ids( + batch_size=setting.batch_size, + input_len=setting.input_len, + vocab_size=vocab_size, + seed=seed, + device=bench_runner.device, + ) + + with _maybe_profile(args=args, setting=setting, stage="prefill"): + prefill_latency, extend_result = _timed_call( + bench_runner.device, + lambda: bench_runner.extend(input_ids), + ) + next_token_ids, state = extend_result + + decode_latencies: list[float] = [] + decode_steps = max(0, setting.output_len - 1) + profile_start_step = args.profile_start_step + if profile_start_step is None: + profile_start_step = decode_steps // 2 if decode_steps else 0 + profile_stop_step = profile_start_step + args.profile_steps + + for step in range(decode_steps): + should_profile_decode = ( + _profile_stage_enabled(args, "decode") + and profile_start_step <= step < profile_stop_step + ) + profile_context = ( + _maybe_profile(args=args, setting=setting, stage="decode", step=step) + if should_profile_decode + else nullcontext() + ) + with profile_context: + decode_latency, decode_result = _timed_call( + bench_runner.device, + lambda: bench_runner.decode(next_token_ids, state), + ) + next_token_ids, state = decode_result + decode_latencies.append(decode_latency) + + if args.log_decode_step and (step + 1) % args.log_decode_step == 0: + logger.info( + "decode step %d/%d: %.6f s", + step + 1, + decode_steps, + decode_latency, + ) + + if not record_result: + return None + + return summarize_latencies( + setting=setting, + prefill_latency=prefill_latency, + decode_latencies=decode_latencies, + run_name=args.run_name, + device=bench_runner.device, + dtype=str(bench_runner.runner.dtype), + cuda_graph=bench_runner.runner.graph_runner is not None, + ) + + +def run_benchmark(cfg: GlobalConfig, args: BenchArgs) -> list[dict[str, Any]]: + _load_hf_config(cfg) + logger.info( + "bench_one_batch bypasses scheduler; max_prefill_tokens/chunked_prefill_size " + "do not chunk this benchmark." 
+ ) + + bench_runner = PymllmBenchRunner.create(cfg) + try: + settings = generate_settings(args) + if not args.skip_warmup and settings: + first = settings[0] + warmup_setting = BenchSetting( + batch_size=first.batch_size, + input_len=first.input_len, + output_len=min(32, first.output_len), + ) + logger.info( + "Warmup: batch_size=%d input_len=%d output_len=%d", + warmup_setting.batch_size, + warmup_setting.input_len, + warmup_setting.output_len, + ) + run_single_setting( + bench_runner=bench_runner, + args=args, + setting=warmup_setting, + seed=args.seed, + record_result=False, + ) + + results: list[dict[str, Any]] = [] + for index, setting in enumerate(settings): + logger.info( + "Benchmark: batch_size=%d input_len=%d output_len=%d", + setting.batch_size, + setting.input_len, + setting.output_len, + ) + result = run_single_setting( + bench_runner=bench_runner, + args=args, + setting=setting, + seed=args.seed + index, + record_result=True, + ) + assert result is not None + _append_jsonl(args.result_filename, result) + logger.info("Result: %s", json.dumps(result, sort_keys=True)) + results.append(result) + return results + finally: + bench_runner.shutdown() + + +def main(argv: Optional[Sequence[str]] = None) -> None: + cfg, args = parse_args(argv) + _configure_logging(cfg.server.log_level) + run_benchmark(cfg, args) + + +if __name__ == "__main__": + main() diff --git a/pymllm/configs/server_config.py b/pymllm/configs/server_config.py index 92d02e05e..34bdd1b04 100644 --- a/pymllm/configs/server_config.py +++ b/pymllm/configs/server_config.py @@ -76,6 +76,7 @@ class ServerConfig: log_level: Literal["debug", "info", "warning", "error", "critical"] = "info" enable_metrics: bool = False show_time_cost: bool = False + enable_debug_timing: bool = False # Log prefill/decode throughput stats every N decode batches (0 = disabled) decode_log_interval: int = 40 diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py index 2178afa99..a50baa13e 100644 --- a/pymllm/executor/model_runner.py +++ b/pymllm/executor/model_runner.py @@ -487,7 +487,12 @@ def _load_quant_config_dict(model_path: str) -> dict: fpath = model_path / fname if fpath.exists(): with open(fpath) as fp: - return json.load(fp) + cfg = json.load(fp) + # config.json stores model metadata at the top level and + # nests quantization details under quantization_config. 
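+                # e.g. config.json: {"architectures": [...], ...,
+                #      "quantization_config": {"quant_method": "compressed-tensors", ...}}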
+ if fname == "config.json" and "quantization_config" in cfg: + return cfg["quantization_config"] + return cfg # Fallback: config.json → quantization_config section config_path = model_path / "config.json" diff --git a/pymllm/layers/__init__.py b/pymllm/layers/__init__.py index 2ecb13965..d328ca7ef 100644 --- a/pymllm/layers/__init__.py +++ b/pymllm/layers/__init__.py @@ -3,7 +3,12 @@ from pymllm.layers.base import MllmBaseLayer from pymllm.layers.embedding import VocabParallelEmbedding from pymllm.layers.layer_norm import LayerNorm -from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear +from pymllm.layers.linear import ( + ColumnParallelLinear, + Linear, + MergedLinear, + RowParallelLinear, +) from pymllm.layers.mlp import MLP, ParallelMLP from pymllm.layers.rms_norm import GemmaRMSNorm, RMSNorm from pymllm.layers.rms_norm_gated import RMSNormGated @@ -38,6 +43,7 @@ "VocabParallelEmbedding", "ColumnParallelLinear", "Linear", + "MergedLinear", "RowParallelLinear", "MLP", "ParallelMLP", diff --git a/pymllm/layers/linear.py b/pymllm/layers/linear.py index b4058c2da..6e4106100 100644 --- a/pymllm/layers/linear.py +++ b/pymllm/layers/linear.py @@ -314,3 +314,154 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: return self.quant_method.apply(self, x, self.bias) + + +class MergedLinear(MllmBaseLayer): + """Non-parallel merged linear layer. + + This is the single-GPU counterpart of SGLang/vLLM merged column + projections. It owns one physical parameter set, while + ``output_partition_sizes`` records the logical shards, e.g. + ``[q_size, k_size, v_size]`` or ``[intermediate_size, intermediate_size]``. + Checkpoints may still store those shards as separate tensors; the + shard-aware loader stacks them into the fused parameter. 
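+
+    Illustrative QKV usage (sizes are made up; ``weight`` assumes the unquantized
+    method registers the fused parameter under that name)::
+
+        qkv = MergedLinear(1024, [1024, 256, 256], bias=False)
+        qkv.weight_loader(qkv.weight, q_weight, "q")  # rows [0, 1024)
+        qkv.weight_loader(qkv.weight, k_weight, "k")  # rows [1024, 1280)
+        qkv.weight_loader(qkv.weight, v_weight, "v")  # rows [1280, 1536)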
+ """ + + def __init__( + self, + in_features: int, + output_partition_sizes: list[int], + bias: bool = True, + quant_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + if not output_partition_sizes: + raise ValueError("output_partition_sizes must not be empty") + if any(size <= 0 for size in output_partition_sizes): + raise ValueError( + "all output_partition_sizes must be positive, got " + f"{output_partition_sizes}" + ) + + self.in_features = in_features + self.output_partition_sizes = list(output_partition_sizes) + self.out_features = sum(self.output_partition_sizes) + + self.quant_method = quant_method or UnquantizedLinearMethod() + self.quant_method.create_weights( + layer=self, + input_size_per_partition=in_features, + output_partition_sizes=self.output_partition_sizes, + input_size=in_features, + output_size=self.out_features, + params_dtype=torch.get_default_dtype(), + weight_loader=self.weight_loader, + ) + + if bias: + self.bias = Parameter(torch.empty(self.out_features)) + set_weight_attrs( + self.bias, + {"output_dim": 0, "weight_loader": self.weight_loader}, + ) + else: + self.register_parameter("bias", None) + + def _actual_offset_for_shard( + self, + param: Parameter, + loaded_weight: torch.Tensor, + output_dim: int, + loaded_shard_id, + ) -> tuple[int, int]: + """Return offset/size in the parameter's actual output dimension.""" + shard_size = loaded_weight.shape[output_dim] + total_size = param.data.shape[output_dim] + + if isinstance(loaded_shard_id, str): + if loaded_shard_id == "q": + return 0, shard_size + if loaded_shard_id == "k": + return total_size - 2 * shard_size, shard_size + if loaded_shard_id == "v": + return total_size - shard_size, shard_size + raise ValueError(f"Unknown QKV shard id: {loaded_shard_id!r}") + + if not isinstance(loaded_shard_id, int): + raise ValueError(f"Unknown shard id: {loaded_shard_id!r}") + if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_partition_sizes): + raise ValueError( + f"shard id {loaded_shard_id} out of range for " + f"{len(self.output_partition_sizes)} partitions" + ) + + logical_total = sum(self.output_partition_sizes) + if total_size == logical_total: + offset = sum(self.output_partition_sizes[:loaded_shard_id]) + elif total_size * self.output_partition_sizes[loaded_shard_id] == ( + logical_total * shard_size + ): + offset = sum( + part * total_size // logical_total + for part in self.output_partition_sizes[:loaded_shard_id] + ) + else: + # Gate/up packed shards are equal-width in the current models. 
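+            # e.g. two packed gate/up shards of N rows each: shard 0 -> offset 0,
+            # shard 1 -> offset N, independent of the logical (unpacked) widths.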
+ offset = loaded_shard_id * shard_size + return offset, shard_size + + def _load_unsharded_metadata( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id, + ) -> None: + if param.data.shape != loaded_weight.shape: + raise AssertionError( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + + if loaded_shard_id is not None and param.data.numel() == 2: + fused_shape = loaded_weight.detach().clone().reshape(-1) + fused_shape[0] = self.out_features + fused_shape[1] = self.in_features + param.data.copy_(fused_shape.reshape_as(param.data).to(param.data.dtype)) + return + + param.data.copy_(loaded_weight) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id=None, + ) -> None: + output_dim = getattr(param, "output_dim", None) + + if loaded_shard_id is None: + if param.data.shape != loaded_weight.shape: + raise AssertionError( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + return + + if output_dim is None: + self._load_unsharded_metadata(param, loaded_weight, loaded_shard_id) + return + + shard_offset, shard_size = self._actual_offset_for_shard( + param, loaded_weight, output_dim, loaded_shard_id + ) + param_data = param.data.narrow(output_dim, shard_offset, shard_size) + if param_data.shape != loaded_weight.shape: + raise AssertionError( + f"Shard shape mismatch: param {param_data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param_data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.quant_method.apply(self, x, self.bias) diff --git a/pymllm/layers/mlp.py b/pymllm/layers/mlp.py index 1894e23ca..514b55c9e 100644 --- a/pymllm/layers/mlp.py +++ b/pymllm/layers/mlp.py @@ -7,7 +7,12 @@ import torch from pymllm.layers.base import MllmBaseLayer -from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear +from pymllm.layers.linear import ( + ColumnParallelLinear, + Linear, + MergedLinear, + RowParallelLinear, +) logger = logging.getLogger(__name__) @@ -73,12 +78,6 @@ def __init__( super().__init__() _validate_mlp_args(hidden_size, intermediate_size, activation) - # Quantized checkpoints store gate_proj / up_proj separately; - # fusing them into a single packed-int32 parameter is impractical, - # so force the unfused path when quantisation is active. 
- if quant_config is not None: - use_fused_gate_up_proj = False - self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.activation = activation @@ -99,8 +98,10 @@ def _get_qm(suffix): ) if use_fused_gate_up_proj: - self.gate_up_proj = Linear( - hidden_size, 2 * intermediate_size, bias=use_bias_gate_up, + self.gate_up_proj = MergedLinear( + hidden_size, + [intermediate_size, intermediate_size], + bias=use_bias_gate_up, quant_method=_get_qm("gate_up_proj"), ) self.gate_proj = None diff --git a/pymllm/layers/rms_norm.py b/pymllm/layers/rms_norm.py index b20b36f30..e9a4c6ed0 100644 --- a/pymllm/layers/rms_norm.py +++ b/pymllm/layers/rms_norm.py @@ -10,6 +10,17 @@ from pymllm.layers.utils import set_weight_attrs +def _torch_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + x_fp32 = x.float() + var = x_fp32.pow(2).mean(dim=-1, keepdim=True) + x_norm = x_fp32 * torch.rsqrt(var + eps) + return x_norm.to(dtype=x.dtype) * weight + + class RMSNorm(MllmBaseLayer): """RMSNorm layer implemented with FlashInfer kernel.""" @@ -26,24 +37,33 @@ def forward( x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if residual is not None: - flashinfer.norm.fused_add_rmsnorm(x, residual, self.weight.data, self.eps) - return x, residual - if x.shape[-1] != self.hidden_size: raise ValueError( f"Expected last dim == hidden_size ({self.hidden_size}), " f"but got input shape {tuple(x.shape)}" ) - # FlashInfer rmsnorm accepts 2D/3D input; flatten higher-rank tensors to 2D. - if x.dim() in (2, 3): - return flashinfer.norm.rmsnorm(x, self.weight, self.eps) - - original_shape = x.shape - x_2d = x.reshape(-1, self.hidden_size) - out = flashinfer.norm.rmsnorm(x_2d, self.weight, self.eps) - return out.reshape(original_shape) + if residual is not None: + try: + flashinfer.norm.fused_add_rmsnorm( + x, residual, self.weight.data, self.eps + ) + return x, residual + except Exception: + residual = x + residual + return _torch_rmsnorm(residual, self.weight, self.eps), residual + + try: + # FlashInfer rmsnorm accepts 2D/3D input; flatten higher-rank tensors to 2D. + if x.dim() in (2, 3): + return flashinfer.norm.rmsnorm(x, self.weight, self.eps) + + original_shape = x.shape + x_2d = x.reshape(-1, self.hidden_size) + out = flashinfer.norm.rmsnorm(x_2d, self.weight, self.eps) + return out.reshape(original_shape) + except Exception: + return _torch_rmsnorm(x, self.weight, self.eps) class GemmaRMSNorm(MllmBaseLayer): diff --git a/pymllm/models/__init__.py b/pymllm/models/__init__.py index 7751b3091..00ed27263 100644 --- a/pymllm/models/__init__.py +++ b/pymllm/models/__init__.py @@ -17,6 +17,10 @@ # (module_path, class_name) _MODEL_REGISTRY: Dict[str, Tuple[str, str]] = { + "Qwen3ForCausalLM": ( + "pymllm.models.qwen3", + "Qwen3ForCausalLM", + ), "Qwen3VLForConditionalGeneration": ( "pymllm.models.qwen3_vl", "Qwen3VLForConditionalGeneration", diff --git a/pymllm/models/qwen3.py b/pymllm/models/qwen3.py new file mode 100644 index 000000000..c17d32dd5 --- /dev/null +++ b/pymllm/models/qwen3.py @@ -0,0 +1,453 @@ +"""Inference-only Qwen3 text model for pymllm. + +Implements Qwen3ForCausalLM with: +- QK-norm attention + 1D RoPE +- RadixAttention KV-cache backend +- Optional quantized Linear methods via quant_config + +Adapted from pymllm's Qwen3-VL text backbone and SGLang's qwen3.py. 
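+
+Runtime entry points (argument names are illustrative; the signatures match the
+classes below):
+
+    model = Qwen3ForCausalLM(hf_config, quant_config=None)
+    logits_output = model(input_ids, positions, forward_batch)
+    model.load_weights(named_weight_iter)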
+""" + +from __future__ import annotations + +import logging +import time +from typing import Iterable, Tuple + +import torch +import torch.nn as nn + +from pymllm.layers import RMSNorm +from pymllm.layers.attention.radix_attention import RadixAttention +from pymllm.layers.linear import Linear, MergedLinear +from pymllm.layers.mlp import MLP +from pymllm.layers.rope import apply_rope_pos_ids + +logger = logging.getLogger(__name__) + + +class Qwen3Attention(nn.Module): + """Qwen3 attention with QK norm + 1D RoPE.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_id: int, + rope_theta: float = 1_000_000.0, + rms_norm_eps: float = 1e-6, + max_position_embeddings: int = 32768, + attention_bias: bool = False, + quant_config=None, + prefix: str = "", + ): + del max_position_embeddings + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + self.scaling = head_dim**-0.5 + self.rope_theta = rope_theta + + def _get_qm(suffix: str): + if quant_config is None: + return None + return quant_config.get_quant_method( + layer=None, + prefix=f"{prefix}.{suffix}" if prefix else suffix, + ) + + self.use_fused_qkv = True + + if self.use_fused_qkv: + self.qkv_proj = MergedLinear( + hidden_size, + [self.q_size, self.kv_size, self.kv_size], + bias=attention_bias, + quant_method=_get_qm("qkv_proj"), + ) + self.q_proj = None + self.k_proj = None + self.v_proj = None + else: + self.qkv_proj = None + self.q_proj = Linear( + hidden_size, + self.q_size, + bias=attention_bias, + quant_method=_get_qm("q_proj"), + ) + self.k_proj = Linear( + hidden_size, + self.kv_size, + bias=attention_bias, + quant_method=_get_qm("k_proj"), + ) + self.v_proj = Linear( + hidden_size, + self.kv_size, + bias=attention_bias, + quant_method=_get_qm("v_proj"), + ) + + self.o_proj = Linear( + self.q_size, + hidden_size, + bias=attention_bias, + quant_method=_get_qm("o_proj"), + ) + + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + self.attn = RadixAttention( + num_heads=num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch, + ) -> torch.Tensor: + if self.use_fused_qkv: + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)) + + # Qwen3 text uses 1D RoPE with position ids from scheduler/model runner. 
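+        # A 2D positions tensor (e.g. mrope-style ids shared with the VL runner)
+        # is collapsed to its first row before applying plain 1D RoPE.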
+ if positions.ndim > 1: + positions = positions[0] + apply_rope_pos_ids( + q, + k, + positions, + inplace=True, + rotary_dim=self.head_dim, + rope_theta=self.rope_theta, + ) + + q = q.reshape(-1, self.q_size) + k = k.reshape(-1, self.kv_size) + + attn_output = self.attn(q, k, v, forward_batch) + return self.o_proj(attn_output) + + +class Qwen3DecoderLayer(nn.Module): + """Single Qwen3 decoder layer.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + intermediate_size: int, + hidden_act: str, + attention_bias: bool, + layer_id: int, + rope_theta: float = 1_000_000.0, + rms_norm_eps: float = 1e-6, + max_position_embeddings: int = 32768, + quant_config=None, + prefix: str = "", + ): + super().__init__() + self.self_attn = Qwen3Attention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + layer_id=layer_id, + rope_theta=rope_theta, + rms_norm_eps=rms_norm_eps, + max_position_embeddings=max_position_embeddings, + attention_bias=attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.self_attn" if prefix else "self_attn", + ) + self.mlp = MLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation=hidden_act, + use_fused_gate_up_proj=True, + use_bias_gate_up=False, + use_bias_down=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp" if prefix else "mlp", + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch, + residual: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn(positions, hidden_states, forward_batch) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class Qwen3Model(nn.Module): + """Qwen3 text backbone (embedding + decoder + final norm).""" + + def __init__(self, config, quant_config=None): + super().__init__() + tc = getattr(config, "text_config", config) + + self.hidden_size = tc.hidden_size + self.num_hidden_layers = tc.num_hidden_layers + + self.embed_tokens = nn.Embedding(tc.vocab_size, tc.hidden_size) + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer( + hidden_size=tc.hidden_size, + num_heads=tc.num_attention_heads, + num_kv_heads=tc.num_key_value_heads, + head_dim=getattr(tc, "head_dim", tc.hidden_size // tc.num_attention_heads), + intermediate_size=tc.intermediate_size, + hidden_act=tc.hidden_act, + attention_bias=getattr(tc, "attention_bias", False), + layer_id=layer_id, + rope_theta=getattr(tc, "rope_theta", 1_000_000.0), + rms_norm_eps=tc.rms_norm_eps, + max_position_embeddings=getattr(tc, "max_position_embeddings", 32768), + quant_config=quant_config, + prefix=f"model.layers.{layer_id}", + ) + for layer_id in range(tc.num_hidden_layers) + ] + ) + self.norm = RMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch, + input_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + residual = 
None + for layer in self.layers: + if residual is not None and not isinstance(layer, Qwen3DecoderLayer): + hidden_states = hidden_states + residual + residual = None + + if isinstance(layer, Qwen3DecoderLayer): + layer_output = layer( + positions, + hidden_states, + forward_batch, + residual=residual, + ) + else: + layer_output = layer(positions, hidden_states, forward_batch) + + if isinstance(layer_output, tuple): + hidden_states, residual = layer_output + else: + hidden_states = layer_output + residual = None + + if residual is None: + return self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3ForCausalLM(nn.Module): + """Inference-only Qwen3ForCausalLM.""" + + def __init__(self, config, quant_config=None): + super().__init__() + tc = getattr(config, "text_config", config) + + self.config = tc + self.quant_config = quant_config + + self.model = Qwen3Model(tc, quant_config=quant_config) + + tie_word_embeddings = getattr(config, "tie_word_embeddings", getattr(tc, "tie_word_embeddings", False)) + if tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = nn.Linear(tc.hidden_size, tc.vocab_size, bias=False) + + def get_input_embeddings(self) -> nn.Module: + return self.model.embed_tokens + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch, + ): + _llm_t0 = time.perf_counter() + hidden_states = self.model(input_ids, positions, forward_batch) + _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0 + + if forward_batch.forward_mode.is_extend(): + forward_batch.llm_prefill_ms = _llm_ms + forward_batch.llm_decode_ms = None + else: + forward_batch.llm_decode_ms = _llm_ms + + # Prefill: keep only last token logits per sequence. + if forward_batch.forward_mode.is_extend(): + if ( + getattr(forward_batch, "extend_start_loc", None) is not None + and getattr(forward_batch, "extend_seq_lens", None) is not None + ): + last_index = ( + forward_batch.extend_start_loc + forward_batch.extend_seq_lens - 1 + ).long() + hidden_states = hidden_states[last_index] + else: + hidden_states = hidden_states[-1:] + + logits = torch.matmul( + hidden_states.to(self.lm_head.weight.dtype), + self.lm_head.weight.T, + ) + + from pymllm.executor.model_runner import LogitsProcessorOutput + + return LogitsProcessorOutput(next_token_logits=logits) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: + stacked_params_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # Keep compatibility with checkpoints that omit the model prefix. 
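+            # e.g. "layers.0.self_attn.q_proj.weight" -> "model.layers.0.self_attn.q_proj.weight"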
+ if not name.startswith("model.") and ( + name.startswith("layers.") + or name.startswith("embed_tokens.") + or name.startswith("norm.") + ): + name = f"model.{name}" + + if tie_word_embeddings and "lm_head.weight" in name: + continue + + name = _remap_weight_name(name) + + handled = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + mapped_name = name.replace(weight_name, param_name) + if mapped_name not in params_dict: + continue + param = params_dict[mapped_name] + loader = getattr(param, "weight_loader", None) + if loader is not None: + loader(param, loaded_weight, shard_id) + else: + _load_stacked_weight(param, loaded_weight, shard_id) + handled = True + break + + if handled: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + loader = getattr(param, "weight_loader", None) + if loader is not None: + loader(param, loaded_weight) + elif param.data.shape == loaded_weight.shape: + param.data.copy_(loaded_weight) + else: + logger.warning( + "Shape mismatch: param %s (%s) vs loaded (%s), skipping.", + name, + tuple(param.data.shape), + tuple(loaded_weight.shape), + ) + + +def _remap_weight_name(name: str) -> str: + """Remap checkpoint weight names to pymllm Qwen3 parameter names.""" + if name.startswith("model.language_model."): + name = name.replace("model.language_model.", "model.", 1) + elif name.startswith("language_model."): + name = name.replace("language_model.", "model.", 1) + return name + + +def _load_stacked_weight( + param: nn.Parameter, + loaded_weight: torch.Tensor, + shard_id, +) -> None: + """Load one shard into a fused parameter (QKV or gate_up).""" + if isinstance(shard_id, str): + # QKV fused layout: [Q, K, V] where Q may be wider than K/V in GQA. 
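+        # Illustrative GQA sizes (not taken from a specific checkpoint): 32 query heads,
+        # 8 KV heads, head_dim 128 -> q rows 4096, k/v rows 1024 each, 6144 rows total.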
+ total_size = param.data.shape[0] + shard_size = loaded_weight.shape[0] + if shard_id == "q": + param.data[0:shard_size].copy_(loaded_weight) + elif shard_id == "k": + kv_size = shard_size + q_size = total_size - 2 * kv_size + param.data[q_size : q_size + kv_size].copy_(loaded_weight) + elif shard_id == "v": + kv_size = shard_size + q_size = total_size - 2 * kv_size + param.data[q_size + kv_size : q_size + 2 * kv_size].copy_(loaded_weight) + else: + # gate_up fused layout: [gate, up] + shard_size = loaded_weight.shape[0] + param.data[shard_id * shard_size : (shard_id + 1) * shard_size].copy_( + loaded_weight + ) diff --git a/pymllm/models/qwen3_vl.py b/pymllm/models/qwen3_vl.py index b253ad091..2b945fc18 100644 --- a/pymllm/models/qwen3_vl.py +++ b/pymllm/models/qwen3_vl.py @@ -27,6 +27,7 @@ from __future__ import annotations import logging +import time from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import numpy as np @@ -36,7 +37,7 @@ from pymllm.layers import RMSNorm, apply_mrope from pymllm.layers.attention.radix_attention import RadixAttention -from pymllm.layers.linear import Linear +from pymllm.layers.linear import Linear, MergedLinear from pymllm.layers.mlp import MLP if TYPE_CHECKING: @@ -162,8 +163,8 @@ def forward( cos = torch.cat([cos, cos], dim=-1) sin = torch.cat([sin, sin], dim=-1) - cos = cos.unsqueeze(1) # [seq, 1, head_dim] - sin = sin.unsqueeze(1) # [seq, 1, head_dim] + cos = cos.unsqueeze(1).to(dtype=q.dtype, device=q.device) # [seq, 1, head_dim] + sin = sin.unsqueeze(1).to(dtype=q.dtype, device=q.device) # [seq, 1, head_dim] q = q * cos + _rotate_half(q) * sin k = k * cos + _rotate_half(k) * sin @@ -399,10 +400,12 @@ def rot_pos_emb( # -- Position embedding interpolation -- def _get_interpolation_indices(self, dim_size: int) -> np.ndarray: - indices = (np.arange(dim_size, dtype=np.float32) + 0.5) * ( - self.num_grid_per_side / dim_size - ) - 0.5 - return np.clip(indices, 0, self.num_grid_per_side - 1) + return np.linspace( + 0, + self.num_grid_per_side - 1, + dim_size, + dtype=np.float32, + ) def _calculate_indices_and_weights( self, h_idxs: np.ndarray, w_idxs: np.ndarray @@ -548,7 +551,9 @@ def forward( def _compute_cu_seqlens_from_grid(grid_thw: torch.Tensor) -> torch.Tensor: """Compute cumulative sequence lengths from grid dimensions.""" grid_np = grid_thw.cpu().numpy() - seq_lens = (grid_np[:, 0] * grid_np[:, 1] * grid_np[:, 2]).astype(np.int32) + seq_lens = np.repeat(grid_np[:, 1] * grid_np[:, 2], grid_np[:, 0]).astype( + np.int32 + ) cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)]) return torch.tensor(cu_seqlens, dtype=torch.int32) @@ -719,13 +724,14 @@ def _get_qm(suffix): layer=None, prefix=f"{prefix}.{suffix}" if prefix else suffix, ) - # When quantized, AWQ checkpoints store q/k/v separately so we - # cannot fuse them into a single packed-int32 parameter. 
- self.use_fused_qkv = quant_config is None + self.use_fused_qkv = True if self.use_fused_qkv: - self.qkv_proj = Linear( - hidden_size, self.q_size + 2 * self.kv_size, bias=False, + self.qkv_proj = MergedLinear( + hidden_size, + [self.q_size, self.kv_size, self.kv_size], + bias=False, + quant_method=_get_qm("qkv_proj"), ) self.q_proj = None self.k_proj = None @@ -861,25 +867,21 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: "ForwardBatch", - deepstack_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Self-attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(positions, hidden_states, forward_batch) - hidden_states = residual + hidden_states - - # Add deepstack embeddings after residual (matches HF ordering) - if deepstack_embeds is not None: - hidden_states = hidden_states + deepstack_embeds + residual: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) - # MLP - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, forward_batch) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states + return hidden_states, residual class Qwen3VLTextModel(nn.Module): @@ -943,18 +945,41 @@ def forward( else: hidden_states = input_embeds + residual = None for layer_idx, layer in enumerate(self.layers): + if residual is not None and not isinstance(layer, Qwen3VLDecoderLayer): + hidden_states = hidden_states + residual + residual = None + + if isinstance(layer, Qwen3VLDecoderLayer): + layer_output = layer( + positions, + hidden_states, + forward_batch, + residual=residual, + ) + else: + layer_output = layer(positions, hidden_states, forward_batch) + + if isinstance(layer_output, tuple): + hidden_states, residual = layer_output + else: + hidden_states = layer_output + residual = None + ds_embeds = _get_deepstack_embeds( layer_idx, input_deepstack_embeds, self.hidden_size ) - hidden_states = layer( - positions, - hidden_states, - forward_batch, - deepstack_embeds=ds_embeds, - ) - - return self.norm(hidden_states) + if ds_embeds is not None: + if residual is not None: + hidden_states = hidden_states + residual + residual = None + hidden_states = hidden_states + ds_embeds + + if residual is None: + return self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states def _get_deepstack_embeds( @@ -996,7 +1021,6 @@ def __init__(self, config, quant_config=None) -> None: text_config = getattr(config, "text_config", config) vision_config = getattr(config, "vision_config", None) - # Vision encoder — NOT quantized if vision_config is not None: self.visual = Qwen3VLVisionModel( @@ -1051,7 +1075,6 @@ def __init__(self, config, quant_config=None) -> None: max_position_embeddings=max_position_embeddings, quant_config=quant_config, ) - # LM head — following sglang's pattern: always use lm_head.weight # for matmul in forward(), so it works whether lm_head is nn.Embedding # (tied) or nn.Linear (untied). 
@@ -1158,6 +1181,10 @@ def forward( input_embeds = None input_deepstack_embeds = None + vit_prefill_ms = None + vit_prefill_tokens = None + llm_prefill_ms = None + llm_decode_ms = None if ( pixel_values is not None @@ -1166,7 +1193,11 @@ def forward( and not forward_batch.forward_mode.is_decode() ): # Run vision encoder - vision_features = self.visual(pixel_values, grid_thw=image_grid_thw) + _vit_t0 = time.perf_counter() + vision_features = ( + self.visual(pixel_values, grid_thw=image_grid_thw) + ) + vit_prefill_ms = (time.perf_counter() - _vit_t0) * 1000.0 # Separate main embeddings and deepstack embeddings if self.num_deepstack_embeddings > 0: @@ -1179,6 +1210,13 @@ def forward( # Get text embeddings and replace image tokens with vision features input_embeds = self.model.embed_tokens(input_ids) image_mask = input_ids == self.image_token_id + vit_prefill_tokens = int(image_mask.sum().item()) + if vit_prefill_tokens != int(vision_embeds.shape[0]): + raise ValueError( + "Image features and image tokens do not match, " + f"tokens: {vit_prefill_tokens}, " + f"features: {vision_embeds.shape[0]}" + ) if image_mask.any(): input_embeds[image_mask] = vision_embeds.to(input_embeds.dtype) @@ -1195,13 +1233,27 @@ def forward( ) # Text decoder - hidden_states = self.model( - input_ids, - positions, - forward_batch, - input_embeds=input_embeds, - input_deepstack_embeds=input_deepstack_embeds, + _llm_t0 = time.perf_counter() + hidden_states = ( + self.model( + input_ids, + positions, + forward_batch, + input_embeds=input_embeds, + input_deepstack_embeds=input_deepstack_embeds, + ) ) + _llm_ms = (time.perf_counter() - _llm_t0) * 1000.0 + + if forward_batch.forward_mode.is_extend(): + llm_prefill_ms = _llm_ms + forward_batch.vit_prefill_ms = vit_prefill_ms + forward_batch.vit_prefill_tokens = vit_prefill_tokens + forward_batch.llm_prefill_ms = llm_prefill_ms + forward_batch.llm_decode_ms = None + else: + llm_decode_ms = _llm_ms + forward_batch.llm_decode_ms = llm_decode_ms # Prune hidden_states before lm_head to avoid a wasteful # [total_tokens, vocab] matmul during prefill. @@ -1240,19 +1292,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: Handles weight name remapping between HuggingFace Qwen3-VL checkpoints and this model's parameter names. """ - # When quantized, the model has separate q/k/v and gate/up projections - # (no fused qkv_proj / gate_up_proj), so skip the stacking logic. 
- if self.quant_config is not None: - stacked_params_mapping = [] - else: - stacked_params_mapping = [ - # (param_name, weight_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".up_proj", 1), - (".gate_up_proj", ".gate_proj", 0), - ] + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), + ] params_dict = dict(self.named_parameters()) @@ -1277,7 +1324,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: name = name.replace(weight_name, param_name) if name not in params_dict: continue - _load_stacked_weight(params_dict[name], loaded_weight, shard_id) + param = params_dict[name] + loader = getattr(param, "weight_loader", None) + if loader is not None: + loader(param, loaded_weight, shard_id) + else: + _load_stacked_weight(param, loaded_weight, shard_id) handled = True break @@ -1332,9 +1384,10 @@ def _remap_weight_name(name: str) -> str: elif name.startswith("model.visual."): name = name.replace("model.visual.", "visual.", 1) - # Vision attention QKV renaming (fused weights in checkpoint) + # Vision attention param renaming (checkpoint -> pymllm names) if "visual" in name: name = name.replace("attn.qkv.", "attn.qkv_proj.") + name = name.replace("attn.proj.", "attn.out_proj.") return name diff --git a/pymllm/orchestrator/detokenizer_process.py b/pymllm/orchestrator/detokenizer_process.py index 1bbda98d0..7b8bf263f 100644 --- a/pymllm/orchestrator/detokenizer_process.py +++ b/pymllm/orchestrator/detokenizer_process.py @@ -116,6 +116,10 @@ def _detokenize(self, token_id_out: Dict[str, Any]) -> List[Dict[str, Any]]: ) prompt_tokens_list: List[int] = token_id_out.get("prompt_tokens", []) completion_tokens_list: List[int] = token_id_out.get("completion_tokens", []) + vit_prefill_ms_list = token_id_out.get("vit_prefill_ms", []) + vit_prefill_tokens_list = token_id_out.get("vit_prefill_tokens", []) + llm_prefill_ms_list = token_id_out.get("llm_prefill_ms", []) + llm_decode_ms_list = token_id_out.get("llm_decode_ms", []) results: List[Dict[str, Any]] = [] @@ -131,6 +135,18 @@ def _detokenize(self, token_id_out: Dict[str, Any]) -> List[Dict[str, Any]]: completion_tokens = ( completion_tokens_list[i] if i < len(completion_tokens_list) else 0 ) + vit_prefill_ms = ( + vit_prefill_ms_list[i] if i < len(vit_prefill_ms_list) else None + ) + vit_prefill_tokens = ( + vit_prefill_tokens_list[i] if i < len(vit_prefill_tokens_list) else None + ) + llm_prefill_ms = ( + llm_prefill_ms_list[i] if i < len(llm_prefill_ms_list) else None + ) + llm_decode_ms = ( + llm_decode_ms_list[i] if i < len(llm_decode_ms_list) else None + ) # Decode text from output_ids if self._tokenizer is not None: @@ -160,6 +176,14 @@ def _detokenize(self, token_id_out: Dict[str, Any]) -> List[Dict[str, Any]]: "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, } + if vit_prefill_ms is not None: + result["vit_prefill_ms"] = vit_prefill_ms + if vit_prefill_tokens is not None: + result["vit_prefill_tokens"] = vit_prefill_tokens + if llm_prefill_ms is not None: + result["llm_prefill_ms"] = llm_prefill_ms + if llm_decode_ms is not None: + result["llm_decode_ms"] = llm_decode_ms results.append(result) return results diff --git a/pymllm/orchestrator/model_runner_process.py b/pymllm/orchestrator/model_runner_process.py index 
a514ac2e9..383fa2da5 100644 --- a/pymllm/orchestrator/model_runner_process.py +++ b/pymllm/orchestrator/model_runner_process.py @@ -20,6 +20,7 @@ """ import logging +import time from typing import Any, Dict, List, Optional, Tuple import torch @@ -373,7 +374,21 @@ def _forward_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: mrope_position_deltas=mrope_deltas_tensor, ) + _forward_t0 = time.perf_counter() logits_output = runner.forward(fb) + _forward_ms = (time.perf_counter() - _forward_t0) * 1000.0 + + # Extract timing info written by multimodal models onto ForwardBatch. + vit_prefill_ms = getattr(fb, "vit_prefill_ms", None) + vit_prefill_tokens = getattr(fb, "vit_prefill_tokens", None) + llm_prefill_ms = getattr(fb, "llm_prefill_ms", None) + llm_decode_ms = getattr(fb, "llm_decode_ms", None) + + # Decode may run through CUDA graph / non-Python execution paths where + # model-level Python timing hooks do not fire. Fall back to the outer + # runner.forward wall-clock time for decode batches. + if forward_mode == "decode" and llm_decode_ms is None: + llm_decode_ms = _forward_ms # Persist M-RoPE position deltas for multimodal models (Qwen3-VL). # The model sets mrope_position_deltas on the ForwardBatch during @@ -424,6 +439,15 @@ def _forward_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: "rid": rid, "output_token_ids": [token_id], } + + if vit_prefill_ms is not None: + out["vit_prefill_ms"] = float(vit_prefill_ms) + if vit_prefill_tokens is not None: + out["vit_prefill_tokens"] = int(vit_prefill_tokens) + if llm_prefill_ms is not None: + out["llm_prefill_ms"] = float(llm_prefill_ms) + if llm_decode_ms is not None: + out["llm_decode_ms"] = float(llm_decode_ms) # Report actual prefix_len back to the scheduler so it can # update its token budget tracking accurately. if actual_prefix_lens is not None: @@ -565,6 +589,11 @@ def _insert_into_radix_cache(self, requests_meta: List[Dict[str, Any]]) -> None: len(new_indices), new_indices[: min(len(new_indices), 8)].tolist(), ) + if not hasattr(cache, "page_size"): + # ChunkCache / no-op cache when disable_radix_cache=True. + self._rid_to_cache_protected_len[rid] = 0 + continue + if cache.page_size == 1: assert len(new_indices) == seq_len, ( f"Re-match length mismatch after insert: " @@ -999,7 +1028,7 @@ def _free_rid_resources(self, rid: str) -> None: # and the eviction callback; here we just remove the rid mapping. self._rid_to_gdn_track_slot.pop(rid, None) - cache_enabled = cache is not None + cache_enabled = cache is not None and not isinstance(cache, ChunkCache) # ---------------------------------------------------------- # Phase 1: Read all KV indices BEFORE freeing anything. diff --git a/pymllm/orchestrator/scheduler_process.py b/pymllm/orchestrator/scheduler_process.py index 3bc3466a1..29e05fe06 100644 --- a/pymllm/orchestrator/scheduler_process.py +++ b/pymllm/orchestrator/scheduler_process.py @@ -123,6 +123,12 @@ class Req: "read_offset", # Prompt length (for token accounting) "prompt_len", + # Timing stats + "vit_prefill_ms", + "vit_prefill_tokens", + "llm_prefill_ms", + "llm_decode_ms", + "decode_start_tic", ) def __init__( @@ -175,6 +181,13 @@ def __init__( # Prompt length self.prompt_len: int = len(input_ids) + # Timing stats + self.vit_prefill_ms = None + self.vit_prefill_tokens = None + self.llm_prefill_ms = None + self.llm_decode_ms = None + self.decode_start_tic = None + def check_finished(self) -> bool: """Check if this request has reached a finish condition. 
@@ -776,6 +789,15 @@ def process_batch_result( # The model runner reports the actual prefix_len it found. # The scheduler originally reserved full input_len in # get_next_batch_to_run; correct the over-reservation now. + if "vit_prefill_ms" in out: + req.vit_prefill_ms = out["vit_prefill_ms"] + if "vit_prefill_tokens" in out: + req.vit_prefill_tokens = out["vit_prefill_tokens"] + if "llm_prefill_ms" in out: + req.llm_prefill_ms = out["llm_prefill_ms"] + if "llm_decode_ms" in out: + req.llm_decode_ms = (req.llm_decode_ms or 0.0) + out["llm_decode_ms"] + if "prefix_len" in out and batch.forward_mode.is_extend(): actual_prefix_len = out["prefix_len"] if actual_prefix_len > req.prefix_len: @@ -808,6 +830,12 @@ def process_batch_result( # Check finish conditions (EOS tokens already in stop_token_ids) req.check_finished() + if batch.forward_mode.is_decode(): + _decode_now = time.perf_counter() + for req in batch.reqs: + if req.decode_start_tic is not None: + req.llm_decode_ms = (_decode_now - req.decode_start_tic) * 1000.0 + # Process batch requests based on forward mode if batch.forward_mode.is_extend(): # Prefill batch: mark as prefilled and route @@ -818,6 +846,9 @@ def process_batch_result( self._model_runner._free_rid_resources(req.rid) self._free_req_resources(req) else: + if req.decode_start_tic is None: + req.decode_start_tic = time.perf_counter() + req.llm_decode_ms = 0.0 self._running_batch.append(req) # --- Accumulate prefill metrics --- @@ -876,6 +907,10 @@ def stream_output(self) -> None: "skip_special_tokens": [True], "prompt_tokens": [req.prompt_len], "completion_tokens": [len(req.output_ids)], + "vit_prefill_ms": [req.vit_prefill_ms], + "vit_prefill_tokens": [req.vit_prefill_tokens], + "llm_prefill_ms": [req.llm_prefill_ms], + "llm_decode_ms": [req.llm_decode_ms], } req.read_offset = len(req.output_ids) self._send_to_detokenizer.send_pyobj(output) @@ -952,6 +987,10 @@ def _collect_finished_output(self, req: Req) -> None: "skip_special_tokens": [True], "prompt_tokens": [req.prompt_len], "completion_tokens": [len(req.output_ids)], + "vit_prefill_ms": [req.vit_prefill_ms], + "vit_prefill_tokens": [req.vit_prefill_tokens], + "llm_prefill_ms": [req.llm_prefill_ms], + "llm_decode_ms": [req.llm_decode_ms], } self._finished.append(output) logger.debug( diff --git a/pymllm/orchestrator/tokenizer_process.py b/pymllm/orchestrator/tokenizer_process.py index 44a4c897c..e0e4139f1 100644 --- a/pymllm/orchestrator/tokenizer_process.py +++ b/pymllm/orchestrator/tokenizer_process.py @@ -371,6 +371,20 @@ def _tokenize( # ------------------------------------------------------------------ # mm_inputs = self._collect_mm_inputs(raw_request, text=input_text) + # If AutoProcessor produced multimodal input_ids, they must override + # the plain tokenizer result. Otherwise the prompt contains only a + # single image placeholder token and won't match the visual features. + if mm_inputs is not None: + image_inputs = mm_inputs.get("image_inputs") + if image_inputs is not None and "input_ids" in image_inputs: + proc_input_ids = image_inputs["input_ids"] + if hasattr(proc_input_ids, "ndim") and proc_input_ids.ndim > 1: + proc_input_ids = proc_input_ids[0] + if hasattr(proc_input_ids, "tolist"): + input_ids = proc_input_ids.tolist() + else: + input_ids = list(proc_input_ids) + # ------------------------------------------------------------------ # # 3. 
Pack into the typed dataclass # ------------------------------------------------------------------ # diff --git a/pymllm/quantization/kernels/__init__.py b/pymllm/quantization/kernels/__init__.py new file mode 100644 index 000000000..41b6c5a26 --- /dev/null +++ b/pymllm/quantization/kernels/__init__.py @@ -0,0 +1,3 @@ +# Kernel implementations for quantization methods. +# Triton kernels live here (Python JIT by Triton compiler). +# CUDA/CUTLASS kernels live in mllm-kernel (compiled by mllm JIT/AOT pipeline). diff --git a/pymllm/quantization/kernels/int8_activation_triton.py b/pymllm/quantization/kernels/int8_activation_triton.py new file mode 100644 index 000000000..f0c9accfa --- /dev/null +++ b/pymllm/quantization/kernels/int8_activation_triton.py @@ -0,0 +1,82 @@ +"""Per-token INT8 activation quantization using Triton. + +Ported from sglang int8_kernel.py (per_token_quant_int8). +Original: sglang/srt/layers/quantization/int8_kernel.py:28-89 +""" +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _per_token_quant_int8( + x_ptr, + xq_ptr, + scale_ptr, + stride_x, + stride_xq, + N, + BLOCK: tl.constexpr, +): + """Triton kernel: per-token dynamic INT8 quantization. + + Each program instance handles one row (token). + Computes absmax, derives scale, quantizes to int8. + """ + row_id = tl.program_id(0) + + cols = tl.arange(0, BLOCK) + mask = cols < N + + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to( + tl.float32 + ) + absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) + scale_x = absmax / 127 + x_q = x * (127 / absmax) + x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8) + + tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) + tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty)) + + +def per_token_quant_int8( + x: torch.Tensor, + scale_dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """Per-token dynamic INT8 quantization. + + Args: + x: Input tensor, any shape with last dim = hidden_dim. Must be contiguous. + scale_dtype: Dtype for scale output (default float32). + + Returns: + x_q: INT8 quantized tensor, same shape as x. + scales: Per-token scales, shape = x.shape[:-1] + (1,). 
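+
+    Example (illustrative sketch; requires a CUDA device):
+        >>> x = torch.randn(4, 2048, device="cuda", dtype=torch.float16)
+        >>> x_q, scales = per_token_quant_int8(x)
+        >>> x_q.dtype, scales.shape
+        (torch.int8, torch.Size([4, 1]))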
+ """ + assert x.is_contiguous(), "Input must be contiguous" + + M = x.numel() // x.shape[-1] + N = x.shape[-1] + x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) + scales = torch.empty( + x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype + ) + + BLOCK = triton.next_power_of_2(N) + num_warps = min(max(BLOCK // 256, 1), 8) + + _per_token_quant_int8[(M,)]( + x, + x_q, + scales, + stride_x=x.stride(-2), + stride_xq=x_q.stride(-2), + N=N, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return x_q, scales diff --git a/pymllm/quantization/methods/__init__.py b/pymllm/quantization/methods/__init__.py index 90367f741..3f799deee 100644 --- a/pymllm/quantization/methods/__init__.py +++ b/pymllm/quantization/methods/__init__.py @@ -8,8 +8,14 @@ AWQMarlinConfig, AWQMarlinLinearMethod, ) +from pymllm.quantization.methods.compressed_tensors import ( + CompressedTensorsConfig, + CompressedTensorsLinearMethod, +) __all__ = [ "AWQMarlinConfig", "AWQMarlinLinearMethod", + "CompressedTensorsConfig", + "CompressedTensorsLinearMethod", ] diff --git a/pymllm/quantization/methods/compressed_tensors.py b/pymllm/quantization/methods/compressed_tensors.py new file mode 100644 index 000000000..a46dcf20a --- /dev/null +++ b/pymllm/quantization/methods/compressed_tensors.py @@ -0,0 +1,597 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Parameter + +from mllm_kernel.cuda.jit import gptq_marlin_gemm, gptq_marlin_repack + +from pymllm.layers.quantize_base import LinearMethodBase +from pymllm.layers.utils import set_weight_attrs +from pymllm.quantization.quant_config import QuantizationConfig, register_quantization + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_TILE = 16 + + +class _ScalarTypeInfo: + def __init__(self, name: str, size_bits: int, type_id: int): + self.name = name + self.size_bits = size_bits + self.id = type_id + + +def _compute_scalar_type_id( + exponent: int, + mantissa: int, + signed: bool, + bias: int, + finite_values_only: bool = False, + nan_repr: int = 1, +) -> int: + bit_offset = 0 + result = 0 + for value, width in [ + (exponent, 8), + (mantissa, 8), + (signed, 1), + (bias, 32), + (finite_values_only, 1), + (nan_repr, 8), + ]: + result |= (int(value) & ((1 << width) - 1)) << bit_offset + bit_offset += width + return result + + +SCALAR_TYPE_UINT4 = _ScalarTypeInfo( + "uint4", 4, _compute_scalar_type_id(0, 4, False, 0) +) +SCALAR_TYPE_UINT4B8 = _ScalarTypeInfo( + "uint4b8", 4, _compute_scalar_type_id(0, 4, False, 8) +) + + +def _weights_cfg(config: Dict[str, Any]) -> Dict[str, Any]: + return config["config_groups"]["group_0"]["weights"] + + +def _input_activations_cfg(config: Dict[str, Any]) -> Optional[Dict[str, Any]]: + return config["config_groups"]["group_0"].get("input_activations") + + +def verify_marlin_supported(group_size: int) -> None: + if group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Unsupported compressed-tensors group_size: {group_size}" + ) + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + if major * 10 + minor < 80: + raise ValueError("compressed-tensors Marlin requires SM80+") + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise 
ValueError("output_size_per_partition must be divisible by 64") + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError("input_size_per_partition must be divisible by 128") + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + "input_size_per_partition must be divisible by group_size" + ) + + +def marlin_make_workspace(device: torch.device) -> torch.Tensor: + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms, dtype=torch.int, device=device, requires_grad=False) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return Parameter( + torch.empty(0, dtype=torch.int32, device=device), requires_grad=False + ) + + +def get_scale_perms(): + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: list[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]] + ) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + return s.reshape((-1, size_n)).contiguous() + + +def replace_parameter( + layer: torch.nn.Module, name: str, new_data: torch.Tensor +) -> None: + layer.register_parameter(name, Parameter(new_data, requires_grad=False)) + + +def _per_token_quant_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Dynamic per-token INT8 quantization using Triton kernel.""" + return _get_triton_quant()(x) + + +def _int8_scaled_mm( + x_q: torch.Tensor, + w_q_t: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """INT8 scaled matmul using CUTLASS kernel.""" + return _get_cutlass_mm()(x_q, w_q_t, x_scale, w_scale, out_dtype, bias) + + +# Lazy-loaded kernel references (populated on first call, reused after) +_triton_quant_fn = None +_cutlass_mm_fn = None + + +def _get_triton_quant(): + global _triton_quant_fn + if _triton_quant_fn is None: + from pymllm.quantization.kernels.int8_activation_triton import ( + per_token_quant_int8, + ) + _triton_quant_fn = per_token_quant_int8 + return _triton_quant_fn + + +def _get_cutlass_mm(): + global _cutlass_mm_fn + if _cutlass_mm_fn is None: + from mllm_kernel.cuda.jit.int8_scaled_mm_cutlass import ( + int8_scaled_mm, + ) + _cutlass_mm_fn = int8_scaled_mm + return _cutlass_mm_fn + + +def _validate_supported_signature(config: "CompressedTensorsConfig") -> str: + if config.quant_format == "pack-quantized": + if config.weight_bits != 4: + raise ValueError( + f"Unsupported compressed-tensors num_bits: {config.weight_bits}" + ) + if config.group_size != 32: + raise ValueError( + f"Unsupported compressed-tensors group_size: {config.group_size}" + ) + if not config.symmetric: + raise ValueError("v1 only supports symmetric compressed-tensors") + if config.actorder is not None: + raise ValueError( + f"Unsupported compressed-tensors actorder: {config.actorder}" + ) + verify_marlin_supported(config.group_size) + return "w4a16" + + if config.quant_format == "int-quantized": + if config.weight_bits != 8: + raise ValueError( + f"Unsupported compressed-tensors num_bits: {config.weight_bits}" + ) + if 
config.group_size is not None: + raise ValueError( + f"Unsupported compressed-tensors group_size: {config.group_size}" + ) + if config.weight_strategy != "channel": + raise ValueError( + f"Unsupported compressed-tensors weight strategy: " + f"{config.weight_strategy}" + ) + if config.weight_type != "int": + raise ValueError( + f"Unsupported compressed-tensors weight type: {config.weight_type}" + ) + if config.weight_dynamic: + raise ValueError("compressed-tensors int8 weights must be static") + if not config.symmetric: + raise ValueError("v1 only supports symmetric compressed-tensors") + if config.actorder is not None: + raise ValueError( + f"Unsupported compressed-tensors actorder: {config.actorder}" + ) + if config.input_bits != 8: + raise ValueError( + f"Unsupported compressed-tensors input num_bits: {config.input_bits}" + ) + if config.input_strategy != "token": + raise ValueError( + f"Unsupported compressed-tensors input strategy: " + f"{config.input_strategy}" + ) + if config.input_type != "int": + raise ValueError( + f"Unsupported compressed-tensors input type: {config.input_type}" + ) + if not config.input_dynamic: + raise ValueError("compressed-tensors int8 inputs must be dynamic") + if not config.input_symmetric: + raise ValueError("v1 only supports symmetric compressed-tensors input") + return "w8a8" + + raise ValueError( + f"Unsupported compressed-tensors format: {config.quant_format}" + ) + + +class CompressedTensorsWNA16Scheme: + def __init__( + self, + *, + weight_bits: int, + group_size: int, + symmetric: bool, + actorder: Optional[str], + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.symmetric = symmetric + self.actorder = actorder + self.pack_factor = 32 // weight_bits + self.quant_type = ( + SCALAR_TYPE_UINT4B8 if symmetric else SCALAR_TYPE_UINT4 + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs: Any, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=self.group_size, + ) + + weight_packed = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs(weight_packed, {"output_dim": 0, **extra_weight_attrs}) + layer.register_parameter("weight_packed", weight_packed) + + weight_scale = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition // self.group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight_scale, {"output_dim": 0, **extra_weight_attrs}) + layer.register_parameter("weight_scale", weight_scale) + + weight_shape = Parameter(torch.empty(2, dtype=torch.int64), requires_grad=False) + set_weight_attrs(weight_shape, extra_weight_attrs) + layer.register_parameter("weight_shape", weight_shape) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.group_size = self.group_size + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.weight_packed.device + size_k = layer.input_size_per_partition + size_n = layer.output_size_per_partition + + verify_marlin_supports_shape( 
+ output_size_per_partition=size_n, + input_size_per_partition=size_k, + input_size=size_k, + group_size=self.group_size, + ) + + layer.workspace = marlin_make_workspace(device) + layer.weight_zero_point = marlin_make_empty_g_idx(device) + layer.weight_g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + repacked_qweight = gptq_marlin_repack( + layer.weight_packed.data.t().contiguous(), + perm=layer.g_idx_sort_indices, + size_k=size_k, + size_n=size_n, + num_bits=self.weight_bits, + ) + repacked_scales = marlin_permute_scales( + layer.weight_scale.data.t().contiguous(), + size_k=size_k, + size_n=size_n, + group_size=self.group_size, + ) + + replace_parameter(layer, "weight_packed", repacked_qweight) + replace_parameter(layer, "weight_scale", repacked_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (layer.output_size_per_partition,) + output = gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=layer.weight_packed, + b_scales=layer.weight_scale, + global_scale=None, + b_zeros=layer.weight_zero_point, + g_idx=layer.weight_g_idx, + perm=layer.g_idx_sort_indices, + workspace=layer.workspace, + b_q_type_id=self.quant_type.id, + size_m=reshaped_x.shape[0], + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + is_k_full=True, + use_fp32_reduce=True, + is_zp_float=False, + ) + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) + + +class CompressedTensorsW8A8Int8Scheme: + def __init__(self, *, weight_bits: int) -> None: + self.weight_bits = weight_bits + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs: Any, + ) -> None: + del output_size + del params_dtype + + output_size_per_partition = sum(output_partition_sizes) + + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs( + weight, {"input_dim": 1, "output_dim": 0, **extra_weight_attrs} + ) + layer.register_parameter("weight", weight) + + weight_scale = Parameter( + torch.empty( + output_size_per_partition, + 1, + dtype=torch.float32, + ), + requires_grad=False, + ) + set_weight_attrs(weight_scale, {"output_dim": 0, **extra_weight_attrs}) + layer.register_parameter("weight_scale", weight_scale) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + del input_size + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if layer.weight.dtype != torch.int8: + raise ValueError( + f"compressed-tensors int8 expects weight dtype int8, got " + f"{layer.weight.dtype}" + ) + + # Store weight as (K, N) column-major for CUTLASS: stride(0)==1. + # Original weight is (N, K) row-major. .contiguous() ensures owned memory, + # .t() gives (K, N) with strides (1, K) = column-major. 
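+        # Illustrative shapes: a (96, 64) int8 weight becomes shape (64, 96) with strides (1, 64).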
+ replace_parameter(layer, "weight", layer.weight.data.contiguous().t()) + + scales = layer.weight_scale.data + if scales.dim() == 2 and scales.shape[1] == 1: + scales = scales[:, 0] + elif scales.dim() != 1: + raise ValueError( + "compressed-tensors int8 expects weight_scale shape [N,1] or [N]" + ) + replace_parameter(layer, "weight_scale", scales.to(torch.float32).contiguous()) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]).contiguous() + out_shape = x.shape[:-1] + (layer.output_size_per_partition,) + + x_q, x_scale = _per_token_quant_int8(reshaped_x) + output = _int8_scaled_mm( + x_q, + layer.weight, + x_scale, + layer.weight_scale, + out_dtype=x.dtype, + bias=bias, + ) + return output.reshape(out_shape) + +class CompressedTensorsLinearMethod(LinearMethodBase): + def __init__( + self, + quant_config: "CompressedTensorsConfig", + signature: str, + ) -> None: + self.quant_config = quant_config + if signature == "w4a16": + self.scheme = CompressedTensorsWNA16Scheme( + weight_bits=quant_config.weight_bits, + group_size=quant_config.group_size, + symmetric=quant_config.symmetric, + actorder=quant_config.actorder, + ) + return + self.scheme = CompressedTensorsW8A8Int8Scheme( + weight_bits=quant_config.weight_bits + ) + + def create_weights(self, *args: Any, **kwargs: Any) -> None: + self.scheme.create_weights(*args, **kwargs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.scheme.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.scheme.apply(layer, x, bias) + + +@register_quantization("compressed-tensors") +class CompressedTensorsConfig(QuantizationConfig): + def __init__( + self, + *, + quant_format: str, + ignore: List[str], + weight_bits: int, + group_size: Optional[int], + weight_strategy: Optional[str], + weight_type: Optional[str], + weight_dynamic: bool, + symmetric: bool, + actorder: Optional[str], + input_bits: Optional[int], + input_strategy: Optional[str], + input_type: Optional[str], + input_dynamic: bool, + input_symmetric: bool, + ) -> None: + super().__init__() + self.quant_format = quant_format + self.ignore = ignore + self.weight_bits = weight_bits + self.group_size = group_size + self.weight_strategy = weight_strategy + self.weight_type = weight_type + self.weight_dynamic = weight_dynamic + self.symmetric = symmetric + self.actorder = actorder + self.input_bits = input_bits + self.input_strategy = input_strategy + self.input_type = input_type + self.input_dynamic = input_dynamic + self.input_symmetric = input_symmetric + + def get_name(self) -> str: + return "compressed-tensors" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @staticmethod + def get_config_filenames() -> List[str]: + return ["config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + weights = _weights_cfg(config) + input_activations = _input_activations_cfg(config) + return cls( + quant_format=config["format"], + ignore=list(config.get("ignore", [])), + weight_bits=weights["num_bits"], + group_size=weights["group_size"], + weight_strategy=weights.get("strategy"), + weight_type=weights.get("type"), + weight_dynamic=bool(weights.get("dynamic", False)), + 
symmetric=weights["symmetric"], + actorder=weights.get("actorder"), + input_bits=( + input_activations.get("num_bits") + if input_activations is not None + else None + ), + input_strategy=( + input_activations.get("strategy") + if input_activations is not None + else None + ), + input_type=( + input_activations.get("type") + if input_activations is not None + else None + ), + input_dynamic=bool( + input_activations.get("dynamic", False) + if input_activations is not None + else False + ), + input_symmetric=bool( + input_activations.get("symmetric", False) + if input_activations is not None + else False + ), + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str = "" + ) -> Optional[CompressedTensorsLinearMethod]: + signature = _validate_supported_signature(self) + if any(ignored and prefix.startswith(ignored) for ignored in self.ignore): + return None + return CompressedTensorsLinearMethod(self, signature) diff --git a/pymllm/server/launch.py b/pymllm/server/launch.py index 7f756d46d..fe35a70f4 100644 --- a/pymllm/server/launch.py +++ b/pymllm/server/launch.py @@ -419,6 +419,56 @@ def _normalize_finish_reason(reason: Optional[str]) -> Optional[str]: return _FINISH_REASON_MAP.get(reason, reason) +def _debug_tps(tokens: int, ms: Optional[float]) -> Optional[float]: + if ms is None or ms <= 0: + return None + return tokens / (ms / 1000.0) + + +def _build_debug_timing( + result: Dict[str, Any], + *, + prompt_tokens: int, + completion_tokens: int, +) -> Dict[str, Any]: + vit_prefill_ms = result.get("vit_prefill_ms") + vit_prefill_tokens = result.get("vit_prefill_tokens") + llm_prefill_ms = result.get("llm_prefill_ms") + llm_decode_ms = result.get("llm_decode_ms") + + return { + "experimental_vit_prefill_ms": vit_prefill_ms, + "experimental_llm_prefill_ms": llm_prefill_ms, + "decode_phase_wall_ms": llm_decode_ms, + "prefill_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "experimental_vit_prefill_tps": ( + None + if vit_prefill_tokens is None + else _debug_tps(int(vit_prefill_tokens), vit_prefill_ms) + ), + "experimental_llm_prefill_tps": _debug_tps(prompt_tokens, llm_prefill_ms), + "decode_phase_output_tps": _debug_tps(completion_tokens, llm_decode_ms), + } + + +def _maybe_add_debug_timing( + payload: Dict[str, Any], + *, + result: Dict[str, Any], + prompt_tokens: int, + completion_tokens: int, +) -> Dict[str, Any]: + cfg = get_global_config() + if cfg.server.enable_debug_timing: + payload["debug_timing"] = _build_debug_timing( + result, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + return payload + + def _build_sampling_params( temperature: Optional[float] = None, top_p: Optional[float] = None, @@ -470,14 +520,21 @@ def _messages_to_prompt( Extra keyword arguments forwarded to ``apply_chat_template`` (e.g. ``enable_thinking=True`` for Qwen3). """ - # Flatten each message into a plain dict for the tokenizer. + # Preserve multimodal message structure for tokenizer.apply_chat_template. msg_dicts: List[Dict[str, Any]] = [] for msg in messages: content = msg.content if isinstance(content, list): - # Multimodal: extract only text parts for the prompt string. 
- text_parts = [p.text for p in content if p.type == "text" and p.text] - content = "\n".join(text_parts) if text_parts else "" + mm_parts: List[Dict[str, Any]] = [] + for part in content: + if part.type == "text" and part.text is not None: + mm_parts.append({"type": "text", "text": part.text}) + elif part.type == "image_url" and part.image_url is not None: + # Keep image content so chat template can emit vision tokens. + mm_parts.append( + {"type": "image", "image": part.image_url.url} + ) + content = mm_parts elif content is None: content = "" d: Dict[str, Any] = {"role": msg.role, "content": content} @@ -721,20 +778,25 @@ async def _stream() -> AsyncIterator[bytes]: prompt_tokens += r.get("prompt_tokens", 0) completion_tokens += r.get("completion_tokens", 0) - return ORJSONResponse( - { - "id": _make_completion_id(), - "object": "text_completion", - "created": int(time.time()), - "model": model_name, - "choices": choices, - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - }, - } + payload = { + "id": _make_completion_id(), + "object": "text_completion", + "created": int(time.time()), + "model": model_name, + "choices": choices, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + _maybe_add_debug_timing( + payload, + result=results[-1] if results else {}, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, ) + return ORJSONResponse(payload) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except RuntimeError as e: @@ -960,27 +1022,32 @@ def _make_sse(delta: Dict[str, Any], finish: Optional[str] = None) -> bytes: if tool_calls_list: message["tool_calls"] = tool_calls_list - return ORJSONResponse( - { - "id": _make_chat_completion_id(), - "object": "chat.completion", - "created": int(time.time()), - "model": model_name, - "choices": [ - { - "index": 0, - "message": message, - "logprobs": None, - "finish_reason": finish_reason, - } - ], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - }, - } + payload = { + "id": _make_chat_completion_id(), + "object": "chat.completion", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "index": 0, + "message": message, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + _maybe_add_debug_timing( + payload, + result=r, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, ) + return ORJSONResponse(payload) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except RuntimeError as e: diff --git a/pymllm/tests/bench_w8a8_activation_quant.py b/pymllm/tests/bench_w8a8_activation_quant.py new file mode 100644 index 000000000..4a7c38826 --- /dev/null +++ b/pymllm/tests/bench_w8a8_activation_quant.py @@ -0,0 +1,107 @@ +"""Benchmark W8A8 activation quantization implementations. + +Covers: torch path (current) and (future) Triton kernel. +This script is reusable across phases. 
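+Output: one row per (M, K) activation shape with the median latency in milliseconds
+for each available backend.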
+ +Usage: + python pymllm/tests/bench_w8a8_activation_quant.py +""" +from __future__ import annotations + +import time + +import torch + + +# --------------------------------------------------------------------------- +# Implementations +# --------------------------------------------------------------------------- + +def torch_per_token_quant_int8(x: torch.Tensor): + """Current torch-based activation quantization.""" + x_fp32 = x.to(torch.float32) + absmax = torch.clamp(x_fp32.abs().amax(dim=-1, keepdim=True), min=1e-10) + x_scale = absmax / 127.0 + x_q = torch.round(x_fp32 / x_scale).clamp(-128, 127).to(torch.int8) + return x_q.contiguous(), x_scale.contiguous() + + +def _try_load_triton_kernel(): + try: + from pymllm.quantization.kernels.int8_activation_triton import per_token_quant_int8 + return per_token_quant_int8 + except Exception: + return None + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + +def bench_fn(fn, args, warmup=5, repeat=20) -> float: + """Returns median latency in ms.""" + for _ in range(warmup): + fn(*args) + torch.cuda.synchronize() + + times = [] + for _ in range(repeat): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn(*args) + torch.cuda.synchronize() + t1 = time.perf_counter() + times.append((t1 - t0) * 1e3) + times.sort() + return times[len(times) // 2] + + +def run_benchmarks(): + device = "cuda" + + shapes = [ + # (M, K) — M=tokens, K=hidden_dim + (1, 2048), + (8, 2048), + (16, 2048), + (32, 2048), + (93, 2048), + (128, 2048), + (256, 2048), + ] + + backends = {} + backends["torch"] = torch_per_token_quant_int8 + + triton_fn = _try_load_triton_kernel() + if triton_fn is not None: + backends["triton"] = triton_fn + + print(f"{'Shape':>16s}", end="") + for name in backends: + print(f" {name:>12s}", end="") + print() + print("-" * (16 + 14 * len(backends))) + + for M, K in shapes: + x = torch.randn(M, K, device=device, dtype=torch.float16) + row_label = f"({M},{K})" + print(f"{row_label:>16s}", end="") + + for name, fn in backends.items(): + try: + ms = bench_fn(fn, (x,)) + print(f" {ms:>9.3f} ms", end="") + except Exception as e: + print(f" {'ERR':>12s}", end="") + + print() + + +if __name__ == "__main__": + print("=" * 50) + print("W8A8 Activation Quantization Benchmark") + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"SM: {torch.cuda.get_device_capability(0)}") + print("=" * 50) + run_benchmarks() diff --git a/pymllm/tests/test_bench_one_batch.py b/pymllm/tests/test_bench_one_batch.py new file mode 100644 index 000000000..cc2a87ae7 --- /dev/null +++ b/pymllm/tests/test_bench_one_batch.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import pytest +import torch + +from pymllm.configs.global_config import GlobalConfig +from pymllm.bench_one_batch import ( + BenchArgs, + BenchSetting, + generate_settings, + make_profile_trace_path, + make_synthetic_input_ids, + parse_args, + summarize_latencies, +) + + +@pytest.fixture(autouse=True) +def _reset_global_config(): + GlobalConfig.reset() + yield + GlobalConfig.reset() + + +def test_parse_args_accepts_server_config_and_list_bench_args(tmp_path): + model_dir = tmp_path / "model" + result_file = tmp_path / "bench.jsonl" + model_dir.mkdir() + + cfg, bench_args = parse_args( + [ + "--server.model_path", + str(model_dir), + "--server.dtype", + "float16", + "--quantization.method", + "compressed-tensors", + "--run-name", + "unit", + "--batch-size", + "1", + 
"4", + "--input-len", + "256", + "512", + "--output-len", + "8", + "16", + "--result-filename", + str(result_file), + "--profile-stage", + "decode", + "--profile-activities", + "CPU", + "GPU", + ] + ) + + assert cfg.server.model_path == model_dir + assert cfg.server.tokenizer_path == model_dir + assert cfg.server.dtype == "float16" + assert cfg.quantization.method == "compressed-tensors" + assert bench_args.run_name == "unit" + assert bench_args.batch_size == [1, 4] + assert bench_args.input_len == [256, 512] + assert bench_args.output_len == [8, 16] + assert bench_args.result_filename == result_file + assert bench_args.profile_stage == "decode" + assert bench_args.profile_activities == ["CPU", "GPU"] + + +def test_generate_settings_has_stable_batch_input_output_order(tmp_path): + args = BenchArgs( + batch_size=[1, 2], + input_len=[256, 512], + output_len=[8], + result_filename=tmp_path / "out.jsonl", + ) + + assert generate_settings(args) == [ + BenchSetting(batch_size=1, input_len=256, output_len=8), + BenchSetting(batch_size=1, input_len=512, output_len=8), + BenchSetting(batch_size=2, input_len=256, output_len=8), + BenchSetting(batch_size=2, input_len=512, output_len=8), + ] + + +def test_make_synthetic_input_ids_is_seeded_int32_and_vocab_capped(): + first = make_synthetic_input_ids( + batch_size=2, + input_len=4, + vocab_size=50_000, + seed=123, + device="cpu", + ) + second = make_synthetic_input_ids( + batch_size=2, + input_len=4, + vocab_size=50_000, + seed=123, + device="cpu", + ) + + assert first.shape == (2, 4) + assert first.dtype == torch.int32 + assert torch.equal(first, second) + assert int(first.min()) >= 0 + assert int(first.max()) < 10_000 + + +def test_summarize_latencies_matches_sglang_style_metrics(): + setting = BenchSetting(batch_size=2, input_len=256, output_len=4) + + result = summarize_latencies( + setting=setting, + prefill_latency=0.5, + decode_latencies=[0.1, 0.2, 0.3], + run_name="unit", + device="cuda", + dtype="torch.float16", + cuda_graph=True, + ) + + assert result["run_name"] == "unit" + assert result["batch_size"] == 2 + assert result["input_len"] == 256 + assert result["output_len"] == 4 + assert result["prefill_latency"] == 0.5 + assert result["prefill_throughput"] == pytest.approx(1024.0) + assert result["median_decode_latency"] == pytest.approx(0.2) + assert result["median_decode_throughput"] == pytest.approx(10.0) + assert result["total_latency"] == pytest.approx(1.1) + assert result["overall_throughput"] == pytest.approx((260 * 2) / 1.1) + assert result["device"] == "cuda" + assert result["dtype"] == "torch.float16" + assert result["cuda_graph"] is True + + +def test_make_profile_trace_path_is_deterministic_and_sanitized(tmp_path): + path = make_profile_trace_path( + output_dir=tmp_path, + prefix="pymllm_profile", + run_name="qwen3/vl w8a8", + setting=BenchSetting(batch_size=1, input_len=256, output_len=8), + stage="decode", + ) + + assert path.parent == tmp_path + assert path.name == "pymllm_profile_qwen3_vl_w8a8_bs1_in256_out8_decode.trace.json" diff --git a/pymllm/tests/test_compressed_tensors_config.py b/pymllm/tests/test_compressed_tensors_config.py new file mode 100644 index 000000000..2ece55d68 --- /dev/null +++ b/pymllm/tests/test_compressed_tensors_config.py @@ -0,0 +1,144 @@ +import copy +import json +import pytest + +from pymllm.executor.model_runner import ModelRunner +from pymllm.quantization import get_quantization_config, list_quantization_methods +from pymllm.quantization.methods.compressed_tensors import ( + 
CompressedTensorsConfig, + CompressedTensorsLinearMethod, +) + + +def _current_ct_config(): + return { + "quant_method": "compressed-tensors", + "format": "pack-quantized", + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 4, + "group_size": 32, + "strategy": "group", + "symmetric": True, + "actorder": None, + }, + }, + }, + "ignore": ["ignore_prefix"], + } + + +def _current_ct_w8a8_config(): + return { + "quant_method": "compressed-tensors", + "format": "int-quantized", + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 8, + "group_size": None, + "strategy": "channel", + "symmetric": True, + "dynamic": False, + "actorder": None, + "type": "int", + }, + "input_activations": { + "num_bits": 8, + "strategy": "token", + "symmetric": True, + "dynamic": True, + "type": "int", + }, + }, + }, + "ignore": ["ignore_prefix"], + } + + +def test_compressed_tensors_is_registered(): + assert "compressed-tensors" in list_quantization_methods() + assert get_quantization_config("compressed-tensors") is CompressedTensorsConfig + + +def test_from_config_parses_current_signature(): + config = CompressedTensorsConfig.from_config( + copy.deepcopy(_current_ct_config()) + ) + + assert config.quant_format == "pack-quantized" + assert config.weight_bits == 4 + assert config.group_size == 32 + assert config.symmetric is True + assert config.actorder is None + assert config.ignore == ["ignore_prefix"] + + +def test_from_config_parses_w8a8_signature(): + config = CompressedTensorsConfig.from_config( + copy.deepcopy(_current_ct_w8a8_config()) + ) + + assert config.quant_format == "int-quantized" + assert config.weight_bits == 8 + assert config.group_size is None + assert config.weight_strategy == "channel" + assert config.weight_type == "int" + assert config.symmetric is True + assert config.input_bits == 8 + assert config.input_strategy == "token" + assert config.input_dynamic is True + assert config.ignore == ["ignore_prefix"] + + +def test_load_quant_config_dict_unwraps_quantization_config_from_config_json( + tmp_path, +): + root_config = { + "architectures": ["Qwen3VLForConditionalGeneration"], + "quantization_config": copy.deepcopy(_current_ct_config()), + } + (tmp_path / "config.json").write_text(json.dumps(root_config)) + + loaded = ModelRunner._load_quant_config_dict(tmp_path) + + assert loaded == root_config["quantization_config"] + + +def test_get_quant_method_respects_ignore(): + config = CompressedTensorsConfig.from_config( + copy.deepcopy(_current_ct_config()) + ) + assert config.get_quant_method(layer=None, prefix="ignore_prefix.layer") is None + + method = config.get_quant_method( + layer=None, + prefix="model.language_model.layers.0.self_attn.q_proj", + ) + assert isinstance(method, CompressedTensorsLinearMethod) + +def test_get_quant_method_rejects_unsupported_signature(): + checkpoint_config = copy.deepcopy(_current_ct_config()) + checkpoint_config["config_groups"]["group_0"]["weights"]["group_size"] = 128 + + config = CompressedTensorsConfig.from_config(checkpoint_config) + + with pytest.raises(ValueError, match="group_size"): + config.get_quant_method( + layer=None, + prefix="model.language_model.layers.0.self_attn.q_proj", + ) + + +def test_get_quant_method_accepts_w8a8_signature(): + config = CompressedTensorsConfig.from_config( + copy.deepcopy(_current_ct_w8a8_config()) + ) + method = config.get_quant_method( + layer=None, + prefix="model.language_model.layers.0.self_attn.q_proj", + ) + assert isinstance(method, 
CompressedTensorsLinearMethod) diff --git a/pymllm/tests/test_compressed_tensors_runtime.py b/pymllm/tests/test_compressed_tensors_runtime.py new file mode 100644 index 000000000..7c8e69824 --- /dev/null +++ b/pymllm/tests/test_compressed_tensors_runtime.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import pytest +import torch +from torch import nn + +import pymllm.quantization.methods.compressed_tensors as ct + + +def _current_ct_config() -> dict: + return { + "quant_method": "compressed-tensors", + "format": "pack-quantized", + "ignore": ["lm_head"], + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 4, + "group_size": 32, + "strategy": "group", + "symmetric": True, + "actorder": None, + "type": "int", + }, + } + }, + } + + +def _current_ct_w8a8_config() -> dict: + return { + "quant_method": "compressed-tensors", + "format": "int-quantized", + "ignore": ["lm_head"], + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 8, + "group_size": None, + "strategy": "channel", + "symmetric": True, + "dynamic": False, + "actorder": None, + "type": "int", + }, + "input_activations": { + "num_bits": 8, + "strategy": "token", + "symmetric": True, + "dynamic": True, + "type": "int", + }, + } + }, + } + + +class _DummyLayer(nn.Module): + pass + + +def _build_quant_method() -> ct.CompressedTensorsLinearMethod: + cfg = ct.CompressedTensorsConfig.from_config(_current_ct_config()) + qm = cfg.get_quant_method( + layer=None, + prefix="model.language_model.layers.0.self_attn.q_proj", + ) + assert isinstance(qm, ct.CompressedTensorsLinearMethod) + return qm + + +def _build_quant_method_w8a8() -> ct.CompressedTensorsLinearMethod: + cfg = ct.CompressedTensorsConfig.from_config(_current_ct_w8a8_config()) + qm = cfg.get_quant_method( + layer=None, + prefix="model.language_model.layers.0.self_attn.q_proj", + ) + assert isinstance(qm, ct.CompressedTensorsLinearMethod) + return qm + + +def _weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: + param.data.copy_(loaded_weight) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_create_weights_registers_checkpoint_parameter_names(): + layer = _DummyLayer() + qm = _build_quant_method() + + with torch.device("cuda"): + qm.create_weights( + layer=layer, + input_size_per_partition=2048, + output_partition_sizes=[2048], + input_size=2048, + output_size=2048, + params_dtype=torch.bfloat16, + weight_loader=_weight_loader, + ) + + assert {"weight_packed", "weight_scale", "weight_shape"} <= set( + layer._parameters + ) + assert tuple(layer.weight_packed.shape) == (2048, 256) + assert tuple(layer.weight_scale.shape) == (2048, 64) + assert tuple(layer.weight_shape.shape) == (2,) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_process_and_apply_use_gptq_repack_and_uint4b8( + monkeypatch: pytest.MonkeyPatch, +): + layer = _DummyLayer() + qm = _build_quant_method() + + with torch.device("cuda"): + qm.create_weights( + layer=layer, + input_size_per_partition=2048, + output_partition_sizes=[2048], + input_size=2048, + output_size=2048, + params_dtype=torch.bfloat16, + weight_loader=_weight_loader, + ) + + with torch.no_grad(): + layer.weight_packed.copy_( + torch.arange( + layer.weight_packed.numel(), + device="cuda", + dtype=torch.int32, + ).reshape_as(layer.weight_packed) + ) + layer.weight_scale.fill_(1) + layer.weight_shape.copy_( + torch.tensor([2048, 2048], device="cuda", 
dtype=torch.int64) + ) + + repack_calls: dict[str, object] = {} + scale_calls: dict[str, object] = {} + workspace = torch.zeros(1, dtype=torch.int32, device="cuda") + empty_tensors: list[torch.Tensor] = [] + + monkeypatch.setattr(ct, "verify_marlin_supports_shape", lambda **_: None) + monkeypatch.setattr( + ct, + "marlin_make_workspace", + lambda device: workspace, + ) + monkeypatch.setattr( + ct, + "marlin_make_empty_g_idx", + lambda device: empty_tensors.append( + torch.empty(0, dtype=torch.int32, device=device) + ) + or empty_tensors[-1], + ) + monkeypatch.setattr( + ct, + "gptq_marlin_repack", + lambda b_q_weight, perm, size_k, size_n, num_bits: repack_calls.update( + { + "b_q_weight": b_q_weight, + "perm": perm, + "size_k": size_k, + "size_n": size_n, + "num_bits": num_bits, + } + ) + or torch.zeros( + (size_k // 16, size_n * 16 // (32 // num_bits)), + dtype=torch.int32, + device=b_q_weight.device, + ), + ) + monkeypatch.setattr( + ct, + "marlin_permute_scales", + lambda s, size_k, size_n, group_size: scale_calls.update( + { + "s": s, + "size_k": size_k, + "size_n": size_n, + "group_size": group_size, + } + ) + or torch.zeros( + (size_k // group_size, size_n), + dtype=s.dtype, + device=s.device, + ), + ) + + calls: dict[str, object] = {} + + def fake_gemm(**kwargs): + calls.update(kwargs) + return torch.zeros( + (kwargs["size_m"], kwargs["size_n"]), + dtype=kwargs["a"].dtype, + device=kwargs["a"].device, + ) + + monkeypatch.setattr(ct, "gptq_marlin_gemm", fake_gemm) + + qm.process_weights_after_loading(layer) + x = torch.randn(2, 2048, device="cuda", dtype=torch.bfloat16) + out = qm.apply(layer, x) + + assert out.shape == (2, 2048) + assert repack_calls["perm"] is layer.g_idx_sort_indices + assert repack_calls["size_k"] == 2048 + assert repack_calls["size_n"] == 2048 + assert repack_calls["num_bits"] == 4 + assert scale_calls["size_k"] == 2048 + assert scale_calls["size_n"] == 2048 + assert scale_calls["group_size"] == 32 + assert calls["workspace"] is workspace + assert calls["b_zeros"] is layer.weight_zero_point + assert calls["g_idx"] is layer.weight_g_idx + assert calls["perm"] is layer.g_idx_sort_indices + assert calls["b_q_type_id"] == ct.SCALAR_TYPE_UINT4B8.id + assert calls["b_q_weight"] is layer.weight_packed + + +def test_w8a8_create_weights_registers_weight_and_scale(): + layer = _DummyLayer() + qm = _build_quant_method_w8a8() + + qm.create_weights( + layer=layer, + input_size_per_partition=64, + output_partition_sizes=[96], + input_size=64, + output_size=96, + params_dtype=torch.float16, + weight_loader=_weight_loader, + ) + + assert {"weight", "weight_scale"} <= set(layer._parameters) + assert tuple(layer.weight.shape) == (96, 64) + assert layer.weight.dtype == torch.int8 + assert tuple(layer.weight_scale.shape) == (96, 1) + assert layer.weight_scale.dtype == torch.float32 + + +def test_w8a8_process_weights_transposes_and_flattens_scales(): + layer = _DummyLayer() + qm = _build_quant_method_w8a8() + qm.create_weights( + layer=layer, + input_size_per_partition=32, + output_partition_sizes=[48], + input_size=32, + output_size=48, + params_dtype=torch.float16, + weight_loader=_weight_loader, + ) + + with torch.no_grad(): + layer.weight.copy_( + torch.arange(layer.weight.numel(), dtype=torch.int8).reshape_as(layer.weight) + ) + layer.weight_scale.copy_( + torch.arange(1, 49, dtype=torch.float32).reshape(48, 1) / 100.0 + ) + + qm.process_weights_after_loading(layer) + + assert tuple(layer.weight.shape) == (32, 48) + # Weight is stored as (K, N) column-major for CUTLASS: 
stride(0)==1 + assert layer.weight.stride(0) == 1, "weight should be column-major for CUTLASS" + assert tuple(layer.weight_scale.shape) == (48,) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_w8a8_apply_matches_reference_for_large_m(): + layer = _DummyLayer() + qm = _build_quant_method_w8a8() + + with torch.device("cuda"): + qm.create_weights( + layer=layer, + input_size_per_partition=64, + output_partition_sizes=[128], + input_size=64, + output_size=128, + params_dtype=torch.float16, + weight_loader=_weight_loader, + ) + + with torch.no_grad(): + layer.weight.copy_( + torch.randint(-127, 128, layer.weight.shape, device="cuda", dtype=torch.int8) + ) + layer.weight_scale.copy_( + torch.rand(layer.weight_scale.shape, device="cuda", dtype=torch.float32) + + 1e-3 + ) + qm.process_weights_after_loading(layer) + + x = torch.randn(32, 64, device="cuda", dtype=torch.float16) + bias = torch.randn(128, device="cuda", dtype=torch.float16) + out = qm.apply(layer, x, bias) + + x_q, x_scale = ct._per_token_quant_int8(x) + ref_i32 = torch._int_mm(x_q, layer.weight).to(torch.float32) + ref = (ref_i32 * x_scale * layer.weight_scale.view(1, -1)).to(x.dtype) + bias + + assert out.shape == (32, 128) + assert torch.allclose(out, ref, atol=2e-1, rtol=2e-1) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_w8a8_apply_supports_small_m_by_padding(): + layer = _DummyLayer() + qm = _build_quant_method_w8a8() + + with torch.device("cuda"): + qm.create_weights( + layer=layer, + input_size_per_partition=64, + output_partition_sizes=[64], + input_size=64, + output_size=64, + params_dtype=torch.float16, + weight_loader=_weight_loader, + ) + + with torch.no_grad(): + layer.weight.copy_( + torch.randint(-127, 128, layer.weight.shape, device="cuda", dtype=torch.int8) + ) + layer.weight_scale.fill_(0.01) + qm.process_weights_after_loading(layer) + + x = torch.randn(2, 64, device="cuda", dtype=torch.float16) + out = qm.apply(layer, x) + + assert out.shape == (2, 64) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_w8a8_apply_uses_triton_quant_and_torch_int_mm( + monkeypatch: pytest.MonkeyPatch, +): + """Verify the W8A8 forward path uses Triton activation quant + torch._int_mm.""" + layer = _DummyLayer() + qm = _build_quant_method_w8a8() + + with torch.device("cuda"): + qm.create_weights( + layer=layer, + input_size_per_partition=64, + output_partition_sizes=[64], + input_size=64, + output_size=64, + params_dtype=torch.float16, + weight_loader=_weight_loader, + ) + + with torch.no_grad(): + layer.weight.copy_( + torch.randint(-127, 128, layer.weight.shape, device="cuda", dtype=torch.int8) + ) + layer.weight_scale.fill_(0.01) + qm.process_weights_after_loading(layer) + + # Track that Triton quantization is called via the cached function ref + triton_quant_calls: list[tuple] = [] + original_triton_quant = ct._get_triton_quant() + + def tracked_triton_quant(x, **kwargs): + triton_quant_calls.append(tuple(x.shape)) + return original_triton_quant(x, **kwargs) + + monkeypatch.setattr(ct, "_triton_quant_fn", tracked_triton_quant) + + x = torch.randn(2, 64, device="cuda", dtype=torch.float16) + bias = torch.randn(64, device="cuda", dtype=torch.float16) + out = qm.apply(layer, x, bias) + + assert out.shape == (2, 64) + assert len(triton_quant_calls) == 1, "Triton quant should be called exactly once" + assert triton_quant_calls[0] == (2, 64) diff --git a/pymllm/tests/test_linear_merged.py 
b/pymllm/tests/test_linear_merged.py new file mode 100644 index 000000000..199434c03 --- /dev/null +++ b/pymllm/tests/test_linear_merged.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import torch +from torch import nn +from torch.nn import Parameter + +from pymllm.layers.linear import MergedLinear +from pymllm.layers.quantize_base import LinearMethodBase +from pymllm.layers.utils import set_weight_attrs +from pymllm.quantization.methods.compressed_tensors import CompressedTensorsConfig + + +def _w8a8_config() -> CompressedTensorsConfig: + return CompressedTensorsConfig.from_config( + { + "quant_method": "compressed-tensors", + "format": "int-quantized", + "ignore": ["lm_head"], + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 8, + "group_size": None, + "strategy": "channel", + "symmetric": True, + "dynamic": False, + "actorder": None, + "type": "int", + }, + "input_activations": { + "num_bits": 8, + "strategy": "token", + "symmetric": True, + "dynamic": True, + "type": "int", + }, + } + }, + } + ) + + +def test_merged_linear_weight_loader_stacks_w8a8_qkv_weight_and_scale(): + qm = _w8a8_config().get_quant_method( + layer=None, + prefix="model.layers.0.self_attn.qkv_proj", + ) + layer = MergedLinear(4, [6, 2, 2], bias=False, quant_method=qm) + + q = torch.full((6, 4), 1, dtype=torch.int8) + k = torch.full((2, 4), 2, dtype=torch.int8) + v = torch.full((2, 4), 3, dtype=torch.int8) + + layer.weight_loader(layer.weight, q, "q") + layer.weight_loader(layer.weight, k, "k") + layer.weight_loader(layer.weight, v, "v") + + assert torch.equal(layer.weight[:6], q) + assert torch.equal(layer.weight[6:8], k) + assert torch.equal(layer.weight[8:10], v) + + q_scale = torch.full((6, 1), 0.1, dtype=torch.float32) + k_scale = torch.full((2, 1), 0.2, dtype=torch.float32) + v_scale = torch.full((2, 1), 0.3, dtype=torch.float32) + + layer.weight_loader(layer.weight_scale, q_scale, "q") + layer.weight_loader(layer.weight_scale, k_scale, "k") + layer.weight_loader(layer.weight_scale, v_scale, "v") + + torch.testing.assert_close(layer.weight_scale[:6], q_scale) + torch.testing.assert_close(layer.weight_scale[6:8], k_scale) + torch.testing.assert_close(layer.weight_scale[8:10], v_scale) + + +class _PackedOutputMethod(LinearMethodBase): + def create_weights( + self, + layer: nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del input_size, output_size, params_dtype + qweight = Parameter( + torch.empty( + input_size_per_partition, + sum(output_partition_sizes) // 8, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs(qweight, {"output_dim": 1, **extra_weight_attrs}) + layer.register_parameter("qweight", qweight) + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = sum(output_partition_sizes) + + def apply( + self, + layer: nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + del layer, bias + return x + + +def test_merged_linear_weight_loader_stacks_packed_output_dim_by_loaded_width(): + layer = MergedLinear(4, [16, 8, 8], bias=False, quant_method=_PackedOutputMethod()) + + q = torch.full((4, 2), 1, dtype=torch.int32) + k = torch.full((4, 1), 2, dtype=torch.int32) + v = torch.full((4, 1), 3, dtype=torch.int32) + + layer.weight_loader(layer.qweight, q, "q") + layer.weight_loader(layer.qweight, k, "k") + 
layer.weight_loader(layer.qweight, v, "v") + + assert torch.equal(layer.qweight[:, :2], q) + assert torch.equal(layer.qweight[:, 2:3], k) + assert torch.equal(layer.qweight[:, 3:4], v) diff --git a/pymllm/tests/test_qwen3_forward_timing.py b/pymllm/tests/test_qwen3_forward_timing.py new file mode 100644 index 000000000..bea886301 --- /dev/null +++ b/pymllm/tests/test_qwen3_forward_timing.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch + +from pymllm.executor.model_runner import LogitsProcessorOutput +from pymllm.models.qwen3 import Qwen3ForCausalLM + + +class _Mode: + def __init__(self, *, is_extend: bool, is_decode: bool): + self._is_extend = is_extend + self._is_decode = is_decode + + def is_extend(self) -> bool: + return self._is_extend + + def is_decode(self) -> bool: + return self._is_decode + + +def _make_config() -> SimpleNamespace: + return SimpleNamespace( + hidden_size=8, + intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + rope_theta=1_000_000.0, + rms_norm_eps=1e-6, + max_position_embeddings=128, + attention_bias=False, + vocab_size=32, + tie_word_embeddings=False, + hidden_act="silu", + ) + + +def test_forward_extend_sets_prefill_timing_and_prunes_hidden_states(monkeypatch): + cfg = _make_config() + model = Qwen3ForCausalLM(cfg) + + def fake_forward(input_ids, positions, forward_batch, input_embeds=None): + del positions, forward_batch, input_embeds + return torch.ones((input_ids.shape[0], cfg.hidden_size), dtype=torch.float32) + + monkeypatch.setattr(model.model, "forward", fake_forward) + + fb = SimpleNamespace( + forward_mode=_Mode(is_extend=True, is_decode=False), + extend_start_loc=torch.tensor([0, 3], dtype=torch.int64), + extend_seq_lens=torch.tensor([3, 2], dtype=torch.int64), + llm_prefill_ms=None, + llm_decode_ms=None, + ) + + out = model.forward( + input_ids=torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64), + positions=torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64), + forward_batch=fb, + ) + + assert isinstance(out, LogitsProcessorOutput) + assert out.next_token_logits.shape == (2, cfg.vocab_size) + assert fb.llm_prefill_ms is not None + assert fb.llm_prefill_ms >= 0.0 + assert fb.llm_decode_ms is None + + +def test_forward_decode_sets_decode_timing(monkeypatch): + cfg = _make_config() + model = Qwen3ForCausalLM(cfg) + + def fake_forward(input_ids, positions, forward_batch, input_embeds=None): + del positions, forward_batch, input_embeds + return torch.ones((input_ids.shape[0], cfg.hidden_size), dtype=torch.float32) + + monkeypatch.setattr(model.model, "forward", fake_forward) + + fb = SimpleNamespace( + forward_mode=_Mode(is_extend=False, is_decode=True), + llm_prefill_ms=None, + llm_decode_ms=None, + ) + + out = model.forward( + input_ids=torch.tensor([7, 8], dtype=torch.int64), + positions=torch.tensor([11, 12], dtype=torch.int64), + forward_batch=fb, + ) + + assert isinstance(out, LogitsProcessorOutput) + assert out.next_token_logits.shape == (2, cfg.vocab_size) + assert fb.llm_prefill_ms is None + assert fb.llm_decode_ms is not None + assert fb.llm_decode_ms >= 0.0 diff --git a/pymllm/tests/test_qwen3_model_registry.py b/pymllm/tests/test_qwen3_model_registry.py new file mode 100644 index 000000000..47504c97a --- /dev/null +++ b/pymllm/tests/test_qwen3_model_registry.py @@ -0,0 +1,7 @@ +from pymllm.models import get_model_class + + +def test_registry_resolves_qwen3_causallm(): + cls = get_model_class("Qwen3ForCausalLM") + assert cls is not 
None + assert cls.__name__ == "Qwen3ForCausalLM" diff --git a/pymllm/tests/test_qwen3_residual_carry.py b/pymllm/tests/test_qwen3_residual_carry.py new file mode 100644 index 000000000..8f925dd6f --- /dev/null +++ b/pymllm/tests/test_qwen3_residual_carry.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch +from torch import nn + +from pymllm.models.qwen3 import Qwen3DecoderLayer +from pymllm.models.qwen3 import Qwen3Model + + +class _RecordingNorm(nn.Module): + def __init__(self, residual_offset: float): + super().__init__() + self.residual_offset = residual_offset + self.seen_residual: list[bool] = [] + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + ): + self.seen_residual.append(residual is not None) + if residual is None: + return x + 1.0 + residual_out = x + residual + return residual_out + self.residual_offset, residual_out + + +class _AttentionAdd(nn.Module): + def forward(self, positions, hidden_states, forward_batch): + del positions, forward_batch + return hidden_states + 3.0 + + +class _MLPAdd(nn.Module): + def forward(self, hidden_states): + return hidden_states + 4.0 + + +class _CarryLayer(nn.Module): + def forward(self, positions, hidden_states, forward_batch, **kwargs): + del positions, forward_batch, kwargs + return hidden_states + 10.0, hidden_states + 100.0 + + +class _TensorLayer(nn.Module): + def forward(self, positions, hidden_states, forward_batch): + del positions, forward_batch + return hidden_states * 2.0 + + +def test_qwen3_decoder_layer_returns_residual_carry_and_fuses_post_attn_norm(): + layer = Qwen3DecoderLayer( + hidden_size=2, + num_heads=1, + num_kv_heads=1, + head_dim=2, + intermediate_size=4, + hidden_act="silu", + attention_bias=False, + layer_id=0, + ) + layer.input_layernorm = _RecordingNorm(residual_offset=10.0) + layer.post_attention_layernorm = _RecordingNorm(residual_offset=20.0) + layer.self_attn = _AttentionAdd() + layer.mlp = _MLPAdd() + + hidden_states = torch.tensor([[1.0, 2.0]]) + + next_hidden, residual = layer( + positions=torch.tensor([0]), + hidden_states=hidden_states, + forward_batch=SimpleNamespace(), + residual=None, + ) + + assert layer.input_layernorm.seen_residual == [False] + assert layer.post_attention_layernorm.seen_residual == [True] + torch.testing.assert_close(residual, torch.tensor([[6.0, 8.0]])) + torch.testing.assert_close(next_hidden, torch.tensor([[30.0, 32.0]])) + + +def test_qwen3_model_materializes_residual_before_tensor_returning_layer(): + cfg = SimpleNamespace( + hidden_size=2, + intermediate_size=4, + num_hidden_layers=2, + num_attention_heads=1, + num_key_value_heads=1, + head_dim=2, + rope_theta=1_000_000.0, + rms_norm_eps=1e-6, + max_position_embeddings=32, + attention_bias=False, + vocab_size=8, + hidden_act="silu", + ) + model = Qwen3Model(cfg) + model.layers = nn.ModuleList([_CarryLayer(), _TensorLayer()]) + model.norm = nn.Identity() + + input_embeds = torch.tensor( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + dtype=torch.float32, + ) + + hidden_states = model( + input_ids=torch.tensor([0, 1, 2], dtype=torch.int64), + positions=torch.tensor([0, 1, 2], dtype=torch.int64), + forward_batch=SimpleNamespace(), + input_embeds=input_embeds, + ) + + torch.testing.assert_close( + hidden_states, + (input_embeds + 10.0 + input_embeds + 100.0) * 2.0, + ) diff --git a/pymllm/tests/test_qwen3_vl_deepstack.py b/pymllm/tests/test_qwen3_vl_deepstack.py new file mode 100644 index 000000000..c36bacf39 --- /dev/null +++ 
b/pymllm/tests/test_qwen3_vl_deepstack.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +import torch.nn as nn + +from pymllm.models.qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLTextModel, + Qwen3VLVisionModel, + _compute_cu_seqlens_from_grid, +) + + +class _AddLayer(nn.Module): + def __init__(self, value: float): + super().__init__() + self.value = value + + def forward(self, positions, hidden_states, forward_batch, **kwargs): + del positions, forward_batch, kwargs + return hidden_states + self.value + + +class _CarryLayer(nn.Module): + def forward(self, positions, hidden_states, forward_batch, **kwargs): + del positions, forward_batch, kwargs + return hidden_states + 10.0, hidden_states + 100.0 + + +class _TensorLayer(nn.Module): + def forward(self, positions, hidden_states, forward_batch): + del positions, forward_batch + return hidden_states * 2.0 + + +class _FinalNorm(nn.Module): + def __init__(self): + super().__init__() + self.seen_residual = None + + def forward(self, hidden_states, residual=None): + self.seen_residual = residual + if residual is None: + return hidden_states + return hidden_states + residual + + +class _Mode: + def is_extend(self) -> bool: + return True + + def is_decode(self) -> bool: + return False + + +class _FakeVisual(nn.Module): + def forward(self, pixel_values, grid_thw): + del pixel_values, grid_thw + return torch.ones((1, 2), dtype=torch.float32) + + +def _make_vl_config() -> SimpleNamespace: + text_config = SimpleNamespace( + hidden_size=2, + intermediate_size=4, + num_hidden_layers=1, + num_attention_heads=1, + num_key_value_heads=1, + head_dim=2, + rope_theta=1_000_000.0, + rms_norm_eps=1e-6, + rope_scaling={"mrope_section": [1, 1, 0], "mrope_interleaved": True}, + max_position_embeddings=32, + vocab_size=8, + ) + vision_config = SimpleNamespace( + depth=0, + hidden_size=2, + intermediate_size=4, + num_heads=1, + in_channels=3, + patch_size=1, + spatial_merge_size=1, + temporal_patch_size=1, + out_hidden_size=2, + num_position_embeddings=4, + deepstack_visual_indexes=[], + ) + return SimpleNamespace( + text_config=text_config, + vision_config=vision_config, + image_token_id=5, + video_token_id=6, + vision_start_token_id=4, + tie_word_embeddings=False, + ) + + +def test_text_model_adds_deepstack_after_decoder_layer(): + model = Qwen3VLTextModel( + vocab_size=8, + hidden_size=2, + intermediate_size=4, + num_hidden_layers=1, + num_attention_heads=1, + num_key_value_heads=1, + head_dim=2, + ) + model.layers = nn.ModuleList([_AddLayer(10.0)]) + model.norm = nn.Identity() + + input_embeds = torch.tensor( + [[1.0, 2.0], [3.0, 4.0]], + dtype=torch.float32, + ) + input_deepstack_embeds = torch.tensor( + [[0.5, 1.5], [2.5, 3.5]], + dtype=torch.float32, + ) + + hidden_states = model( + input_ids=torch.tensor([0, 1], dtype=torch.int64), + positions=torch.zeros((3, 2), dtype=torch.int64), + forward_batch=SimpleNamespace(), + input_embeds=input_embeds, + input_deepstack_embeds=input_deepstack_embeds, + ) + + torch.testing.assert_close( + hidden_states, + input_embeds + 10.0 + input_deepstack_embeds, + ) + + +def test_text_model_deepstack_resets_residual_carry_before_injection(): + model = Qwen3VLTextModel( + vocab_size=8, + hidden_size=2, + intermediate_size=4, + num_hidden_layers=1, + num_attention_heads=1, + num_key_value_heads=1, + head_dim=2, + ) + final_norm = _FinalNorm() + model.layers = nn.ModuleList([_CarryLayer()]) + model.norm = final_norm + + 
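# _CarryLayer returns a (hidden, residual) tuple; the model is expected to materialize that carry (hidden + residual) before adding the deepstack embeds, so the final norm below is called with residual=None. +    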
input_embeds = torch.tensor( + [[1.0, 2.0], [3.0, 4.0]], + dtype=torch.float32, + ) + input_deepstack_embeds = torch.tensor( + [[0.5, 1.5], [2.5, 3.5]], + dtype=torch.float32, + ) + + hidden_states = model( + input_ids=torch.tensor([0, 1], dtype=torch.int64), + positions=torch.zeros((3, 2), dtype=torch.int64), + forward_batch=SimpleNamespace(), + input_embeds=input_embeds, + input_deepstack_embeds=input_deepstack_embeds, + ) + + assert final_norm.seen_residual is None + torch.testing.assert_close( + hidden_states, + input_embeds + 10.0 + input_embeds + 100.0 + input_deepstack_embeds, + ) + + +def test_text_model_materializes_residual_before_tensor_returning_layer(): + model = Qwen3VLTextModel( + vocab_size=8, + hidden_size=2, + intermediate_size=4, + num_hidden_layers=2, + num_attention_heads=1, + num_key_value_heads=1, + head_dim=2, + ) + model.layers = nn.ModuleList([_CarryLayer(), _TensorLayer()]) + model.norm = nn.Identity() + + input_embeds = torch.tensor( + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + dtype=torch.float32, + ) + + hidden_states = model( + input_ids=torch.tensor([0, 1, 2], dtype=torch.int64), + positions=torch.zeros((3, 3), dtype=torch.int64), + forward_batch=SimpleNamespace(), + input_embeds=input_embeds, + ) + + torch.testing.assert_close( + hidden_states, + (input_embeds + 10.0 + input_embeds + 100.0) * 2.0, + ) + + +def test_forward_rejects_mismatched_image_token_and_feature_counts(): + model = Qwen3VLForConditionalGeneration(_make_vl_config()) + model.visual = _FakeVisual() + + forward_batch = SimpleNamespace( + forward_mode=_Mode(), + batch_size=1, + extend_start_loc=torch.tensor([0], dtype=torch.int64), + extend_seq_lens=torch.tensor([5], dtype=torch.int64), + pixel_values=torch.zeros((1, 3), dtype=torch.float32), + image_grid_thw=torch.tensor([[1, 1, 2]], dtype=torch.int64), + ) + + with pytest.raises( + ValueError, + match="Image features and image tokens do not match", + ): + model( + input_ids=torch.tensor([1, 4, 5, 5, 2], dtype=torch.int64), + positions=torch.arange(5, dtype=torch.int64), + forward_batch=forward_batch, + ) + + +def test_vision_interpolation_indices_match_sglang_hf(): + model = Qwen3VLVisionModel( + depth=0, + hidden_size=2, + intermediate_size=4, + num_heads=1, + in_channels=3, + patch_size=1, + spatial_merge_size=1, + temporal_patch_size=1, + out_hidden_size=2, + num_position_embeddings=16, + deepstack_visual_indexes=[], + ) + + np.testing.assert_allclose( + model._get_interpolation_indices(3), + np.linspace(0, 3, 3, dtype=np.float32), + ) + + +def test_vision_cu_seqlens_expands_temporal_frames_like_sglang_hf(): + cu_seqlens = _compute_cu_seqlens_from_grid( + torch.tensor([[2, 3, 5], [1, 2, 2]], dtype=torch.int64) + ) + + torch.testing.assert_close( + cu_seqlens, + torch.tensor([0, 15, 30, 34], dtype=torch.int32), + ) diff --git a/pymllm/tests/test_qwen3_vl_weight_loading.py b/pymllm/tests/test_qwen3_vl_weight_loading.py new file mode 100644 index 000000000..e4ea6ab53 --- /dev/null +++ b/pymllm/tests/test_qwen3_vl_weight_loading.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch + +from pymllm.models.qwen3_vl import Qwen3VLForConditionalGeneration +from pymllm.quantization.methods.compressed_tensors import CompressedTensorsConfig + + +def _make_vl_config() -> SimpleNamespace: + text_config = SimpleNamespace( + hidden_size=8, + intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + rope_theta=1_000_000.0, + rms_norm_eps=1e-6, + 
rope_scaling={"mrope_section": [2, 1, 1], "mrope_interleaved": True}, + max_position_embeddings=32, + vocab_size=32, + ) + return SimpleNamespace( + text_config=text_config, + vision_config=None, + image_token_id=5, + video_token_id=6, + vision_start_token_id=4, + tie_word_embeddings=False, + ) + + +def _make_w8a8_config() -> CompressedTensorsConfig: + return CompressedTensorsConfig.from_config( + { + "quant_method": "compressed-tensors", + "format": "int-quantized", + "ignore": ["lm_head"], + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 8, + "group_size": None, + "strategy": "channel", + "symmetric": True, + "dynamic": False, + "actorder": None, + "type": "int", + }, + "input_activations": { + "num_bits": 8, + "strategy": "token", + "symmetric": True, + "dynamic": True, + "type": "int", + }, + } + }, + } + ) + + +def _int8(shape: tuple[int, ...], value: int) -> torch.Tensor: + return torch.full(shape, value, dtype=torch.int8) + + +def test_quantized_vl_text_loads_fused_qkv_and_gate_up_weight_and_scale(): + cfg = _make_vl_config() + text_cfg = cfg.text_config + model = Qwen3VLForConditionalGeneration(cfg, quant_config=_make_w8a8_config()) + + layer0 = model.model.layers[0] + assert layer0.self_attn.use_fused_qkv + assert layer0.self_attn.qkv_proj is not None + assert layer0.self_attn.q_proj is None + assert layer0.self_attn.k_proj is None + assert layer0.self_attn.v_proj is None + assert layer0.mlp.use_fused_gate_up_proj + assert layer0.mlp.gate_up_proj is not None + assert layer0.mlp.gate_proj is None + assert layer0.mlp.up_proj is None + + q_size = text_cfg.num_attention_heads * text_cfg.head_dim + kv_size = text_cfg.num_key_value_heads * text_cfg.head_dim + hidden = text_cfg.hidden_size + inter = text_cfg.intermediate_size + + weights = { + "model.layers.0.self_attn.q_proj.weight": _int8((q_size, hidden), 1), + "model.layers.0.self_attn.k_proj.weight": _int8((kv_size, hidden), 2), + "model.layers.0.self_attn.v_proj.weight": _int8((kv_size, hidden), 3), + "model.layers.0.self_attn.q_proj.weight_scale": torch.full((q_size, 1), 0.1), + "model.layers.0.self_attn.k_proj.weight_scale": torch.full((kv_size, 1), 0.2), + "model.layers.0.self_attn.v_proj.weight_scale": torch.full((kv_size, 1), 0.3), + "model.layers.0.mlp.gate_proj.weight": _int8((inter, hidden), 4), + "model.layers.0.mlp.up_proj.weight": _int8((inter, hidden), 5), + "model.layers.0.mlp.gate_proj.weight_scale": torch.full((inter, 1), 0.4), + "model.layers.0.mlp.up_proj.weight_scale": torch.full((inter, 1), 0.5), + } + + model.load_weights(weights.items()) + + qkv = layer0.self_attn.qkv_proj + assert torch.equal(qkv.weight[:q_size], weights["model.layers.0.self_attn.q_proj.weight"]) + assert torch.equal( + qkv.weight[q_size : q_size + kv_size], + weights["model.layers.0.self_attn.k_proj.weight"], + ) + assert torch.equal( + qkv.weight[q_size + kv_size : q_size + 2 * kv_size], + weights["model.layers.0.self_attn.v_proj.weight"], + ) + torch.testing.assert_close( + qkv.weight_scale[:q_size], + weights["model.layers.0.self_attn.q_proj.weight_scale"], + ) + torch.testing.assert_close( + qkv.weight_scale[q_size : q_size + kv_size], + weights["model.layers.0.self_attn.k_proj.weight_scale"], + ) + torch.testing.assert_close( + qkv.weight_scale[q_size + kv_size : q_size + 2 * kv_size], + weights["model.layers.0.self_attn.v_proj.weight_scale"], + ) + + gate_up = layer0.mlp.gate_up_proj + assert torch.equal(gate_up.weight[:inter], weights["model.layers.0.mlp.gate_proj.weight"]) + assert 
torch.equal(gate_up.weight[inter : 2 * inter], weights["model.layers.0.mlp.up_proj.weight"]) + torch.testing.assert_close( + gate_up.weight_scale[:inter], + weights["model.layers.0.mlp.gate_proj.weight_scale"], + ) + torch.testing.assert_close( + gate_up.weight_scale[inter : 2 * inter], + weights["model.layers.0.mlp.up_proj.weight_scale"], + ) diff --git a/pymllm/tests/test_qwen3_weight_loading.py b/pymllm/tests/test_qwen3_weight_loading.py new file mode 100644 index 000000000..09447b850 --- /dev/null +++ b/pymllm/tests/test_qwen3_weight_loading.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import torch + +from pymllm.models.qwen3 import Qwen3ForCausalLM +from pymllm.quantization.methods.compressed_tensors import CompressedTensorsConfig + + +def _make_config() -> SimpleNamespace: + return SimpleNamespace( + hidden_size=8, + intermediate_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + rope_theta=1_000_000.0, + rms_norm_eps=1e-6, + max_position_embeddings=128, + attention_bias=False, + vocab_size=32, + tie_word_embeddings=False, + hidden_act="silu", + ) + + +def _make_weight(shape: tuple[int, ...], start: int) -> torch.Tensor: + numel = 1 + for s in shape: + numel *= s + return torch.arange(start, start + numel, dtype=torch.float32).reshape(shape) + + +def _make_int8_weight(shape: tuple[int, ...], value: int) -> torch.Tensor: + return torch.full(shape, value, dtype=torch.int8) + + +def _make_w8a8_config() -> CompressedTensorsConfig: + return CompressedTensorsConfig.from_config( + { + "quant_method": "compressed-tensors", + "format": "int-quantized", + "ignore": ["lm_head"], + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 8, + "group_size": None, + "strategy": "channel", + "symmetric": True, + "dynamic": False, + "actorder": None, + "type": "int", + }, + "input_activations": { + "num_bits": 8, + "strategy": "token", + "symmetric": True, + "dynamic": True, + "type": "int", + }, + } + }, + } + ) + + +def _build_language_weights(cfg: SimpleNamespace, layer_prefix: str = "model"): + q_size = cfg.num_attention_heads * cfg.head_dim + kv_size = cfg.num_key_value_heads * cfg.head_dim + hidden = cfg.hidden_size + inter = cfg.intermediate_size + + weights = { + f"{layer_prefix}.embed_tokens.weight": _make_weight((cfg.vocab_size, hidden), 1000), + f"{layer_prefix}.norm.weight": _make_weight((hidden,), 2000), + "lm_head.weight": _make_weight((cfg.vocab_size, hidden), 3000), + } + + for i in range(cfg.num_hidden_layers): + base = 10_000 * (i + 1) + p = f"{layer_prefix}.layers.{i}" + weights[f"{p}.input_layernorm.weight"] = _make_weight((hidden,), base + 1) + weights[f"{p}.post_attention_layernorm.weight"] = _make_weight((hidden,), base + 101) + + weights[f"{p}.self_attn.q_proj.weight"] = _make_weight((q_size, hidden), base + 1001) + weights[f"{p}.self_attn.k_proj.weight"] = _make_weight((kv_size, hidden), base + 2001) + weights[f"{p}.self_attn.v_proj.weight"] = _make_weight((kv_size, hidden), base + 3001) + weights[f"{p}.self_attn.o_proj.weight"] = _make_weight((hidden, q_size), base + 4001) + weights[f"{p}.self_attn.q_norm.weight"] = _make_weight((cfg.head_dim,), base + 5001) + weights[f"{p}.self_attn.k_norm.weight"] = _make_weight((cfg.head_dim,), base + 6001) + + weights[f"{p}.mlp.gate_proj.weight"] = _make_weight((inter, hidden), base + 7001) + weights[f"{p}.mlp.up_proj.weight"] = _make_weight((inter, hidden), base + 8001) + weights[f"{p}.mlp.down_proj.weight"] 
= _make_weight((hidden, inter), base + 9001) + + return weights + + +def test_load_weights_stacks_qkv_and_gate_up_from_model_prefix(): + cfg = _make_config() + model = Qwen3ForCausalLM(cfg) + + weights = _build_language_weights(cfg, layer_prefix="model") + model.load_weights(weights.items()) + + layer0 = model.model.layers[0] + q_size = cfg.num_attention_heads * cfg.head_dim + kv_size = cfg.num_key_value_heads * cfg.head_dim + + q = weights["model.layers.0.self_attn.q_proj.weight"] + k = weights["model.layers.0.self_attn.k_proj.weight"] + v = weights["model.layers.0.self_attn.v_proj.weight"] + qkv = layer0.self_attn.qkv_proj.weight.data + assert torch.equal(qkv[:q_size], q) + assert torch.equal(qkv[q_size : q_size + kv_size], k) + assert torch.equal(qkv[q_size + kv_size : q_size + 2 * kv_size], v) + + gate = weights["model.layers.0.mlp.gate_proj.weight"] + up = weights["model.layers.0.mlp.up_proj.weight"] + gate_up = layer0.mlp.gate_up_proj.weight.data + assert torch.equal(gate_up[: cfg.intermediate_size], gate) + assert torch.equal(gate_up[cfg.intermediate_size :], up) + + assert torch.equal(model.model.embed_tokens.weight.data, weights["model.embed_tokens.weight"]) + assert torch.equal(model.model.norm.weight.data, weights["model.norm.weight"]) + assert torch.equal(model.lm_head.weight.data, weights["lm_head.weight"]) + + +def test_load_weights_accepts_model_language_model_prefix(): + cfg = _make_config() + model = Qwen3ForCausalLM(cfg) + + weights = _build_language_weights(cfg, layer_prefix="model.language_model") + model.load_weights(weights.items()) + + layer1 = model.model.layers[1] + q = weights["model.language_model.layers.1.self_attn.q_proj.weight"] + k = weights["model.language_model.layers.1.self_attn.k_proj.weight"] + v = weights["model.language_model.layers.1.self_attn.v_proj.weight"] + + q_size = cfg.num_attention_heads * cfg.head_dim + kv_size = cfg.num_key_value_heads * cfg.head_dim + qkv = layer1.self_attn.qkv_proj.weight.data + + assert torch.equal(qkv[:q_size], q) + assert torch.equal(qkv[q_size : q_size + kv_size], k) + assert torch.equal(qkv[q_size + kv_size : q_size + 2 * kv_size], v) + + +def test_quantized_load_weights_stacks_qkv_and_gate_up_weight_and_scale(): + cfg = _make_config() + model = Qwen3ForCausalLM(cfg, quant_config=_make_w8a8_config()) + + layer0 = model.model.layers[0] + assert layer0.self_attn.use_fused_qkv + assert layer0.self_attn.qkv_proj is not None + assert layer0.self_attn.q_proj is None + assert layer0.self_attn.k_proj is None + assert layer0.self_attn.v_proj is None + assert layer0.mlp.use_fused_gate_up_proj + assert layer0.mlp.gate_up_proj is not None + assert layer0.mlp.gate_proj is None + assert layer0.mlp.up_proj is None + + q_size = cfg.num_attention_heads * cfg.head_dim + kv_size = cfg.num_key_value_heads * cfg.head_dim + hidden = cfg.hidden_size + inter = cfg.intermediate_size + + weights = { + "model.layers.0.self_attn.q_proj.weight": _make_int8_weight((q_size, hidden), 1), + "model.layers.0.self_attn.k_proj.weight": _make_int8_weight((kv_size, hidden), 2), + "model.layers.0.self_attn.v_proj.weight": _make_int8_weight((kv_size, hidden), 3), + "model.layers.0.self_attn.q_proj.weight_scale": torch.full((q_size, 1), 0.1), + "model.layers.0.self_attn.k_proj.weight_scale": torch.full((kv_size, 1), 0.2), + "model.layers.0.self_attn.v_proj.weight_scale": torch.full((kv_size, 1), 0.3), + "model.layers.0.mlp.gate_proj.weight": _make_int8_weight((inter, hidden), 4), + "model.layers.0.mlp.up_proj.weight": _make_int8_weight((inter, hidden), 5), + 
"model.layers.0.mlp.gate_proj.weight_scale": torch.full((inter, 1), 0.4), + "model.layers.0.mlp.up_proj.weight_scale": torch.full((inter, 1), 0.5), + } + + model.load_weights(weights.items()) + + qkv = layer0.self_attn.qkv_proj + assert torch.equal(qkv.weight[:q_size], weights["model.layers.0.self_attn.q_proj.weight"]) + assert torch.equal( + qkv.weight[q_size : q_size + kv_size], + weights["model.layers.0.self_attn.k_proj.weight"], + ) + assert torch.equal( + qkv.weight[q_size + kv_size : q_size + 2 * kv_size], + weights["model.layers.0.self_attn.v_proj.weight"], + ) + torch.testing.assert_close( + qkv.weight_scale[:q_size], + weights["model.layers.0.self_attn.q_proj.weight_scale"], + ) + torch.testing.assert_close( + qkv.weight_scale[q_size : q_size + kv_size], + weights["model.layers.0.self_attn.k_proj.weight_scale"], + ) + torch.testing.assert_close( + qkv.weight_scale[q_size + kv_size : q_size + 2 * kv_size], + weights["model.layers.0.self_attn.v_proj.weight_scale"], + ) + + gate_up = layer0.mlp.gate_up_proj + assert torch.equal( + gate_up.weight[:inter], + weights["model.layers.0.mlp.gate_proj.weight"], + ) + assert torch.equal( + gate_up.weight[inter : 2 * inter], + weights["model.layers.0.mlp.up_proj.weight"], + ) + torch.testing.assert_close( + gate_up.weight_scale[:inter], + weights["model.layers.0.mlp.gate_proj.weight_scale"], + ) + torch.testing.assert_close( + gate_up.weight_scale[inter : 2 * inter], + weights["model.layers.0.mlp.up_proj.weight_scale"], + ) diff --git a/pymllm/tests/test_rms_norm.py b/pymllm/tests/test_rms_norm.py new file mode 100644 index 000000000..9663f5444 --- /dev/null +++ b/pymllm/tests/test_rms_norm.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import torch + +import pymllm.layers.rms_norm as rms_norm_module +from pymllm.layers.rms_norm import RMSNorm + + +def test_rms_norm_residual_fallback_returns_updated_residual(monkeypatch): + def fail_fused_add_rmsnorm(*args, **kwargs): + del args, kwargs + raise RuntimeError("force torch fallback") + + monkeypatch.setattr( + rms_norm_module.flashinfer.norm, + "fused_add_rmsnorm", + fail_fused_add_rmsnorm, + ) + + norm = RMSNorm(hidden_size=3, eps=1e-6) + norm.weight.data.fill_(1.0) + x = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32) + residual = torch.tensor([[4.0, 5.0, 6.0]], dtype=torch.float32) + + _, residual_out = norm(x, residual) + + torch.testing.assert_close(residual_out, x + residual) diff --git a/pymllm/tests/test_server_debug_timing.py b/pymllm/tests/test_server_debug_timing.py new file mode 100644 index 000000000..32b1b6aef --- /dev/null +++ b/pymllm/tests/test_server_debug_timing.py @@ -0,0 +1,78 @@ +import pytest + +from pymllm.configs.global_config import GlobalConfig, make_args, read_args +from pymllm.configs.server_config import ServerConfig +from pymllm.server import launch + + +@pytest.fixture(autouse=True) +def reset_global_config(): + GlobalConfig.reset() + yield + GlobalConfig.reset() + + +def test_server_debug_timing_is_disabled_by_default(): + assert ServerConfig(model_path=None).enable_debug_timing is False + + +def test_server_debug_timing_can_be_enabled_from_cli(): + cfg = read_args( + ["--server.enable_debug_timing"], + parser=make_args(), + ) + + assert cfg.server.enable_debug_timing is True + + +def test_debug_timing_is_not_added_when_disabled(): + cfg = GlobalConfig.get_instance() + cfg.server.enable_debug_timing = False + payload = {"id": "chatcmpl-test"} + + assert hasattr(launch, "_maybe_add_debug_timing") + launch._maybe_add_debug_timing( + payload, + 
result={ + "vit_prefill_ms": 12.5, + "vit_prefill_tokens": 25, + "llm_prefill_ms": 50.0, + "llm_decode_ms": 200.0, + }, + prompt_tokens=100, + completion_tokens=20, + ) + + assert "timing" not in payload + assert "debug_timing" not in payload + + +def test_debug_timing_uses_debug_field_names_when_enabled(): + cfg = GlobalConfig.get_instance() + cfg.server.enable_debug_timing = True + payload = {"id": "chatcmpl-test"} + + assert hasattr(launch, "_maybe_add_debug_timing") + launch._maybe_add_debug_timing( + payload, + result={ + "vit_prefill_ms": 12.5, + "vit_prefill_tokens": 25, + "llm_prefill_ms": 50.0, + "llm_decode_ms": 200.0, + }, + prompt_tokens=100, + completion_tokens=20, + ) + + assert "timing" not in payload + assert payload["debug_timing"] == { + "experimental_vit_prefill_ms": 12.5, + "experimental_llm_prefill_ms": 50.0, + "decode_phase_wall_ms": 200.0, + "prefill_tokens": 100, + "completion_tokens": 20, + "experimental_vit_prefill_tps": pytest.approx(2000.0), + "experimental_llm_prefill_tps": pytest.approx(2000.0), + "decode_phase_output_tps": pytest.approx(100.0), + }