
Add FP8 kernel acceleration for compressed-tensors quantized models #45699

Draft

jiqing-feng wants to merge 4 commits into huggingface:main from jiqing-feng:fp8

Conversation

@jiqing-feng (Contributor) commented Apr 29, 2026

What does this PR do?

This PR adds native FP8 matmul kernel support for compressed-tensors FP8 quantized models in transformers. Previously, compressed-tensors FP8 models were loaded via the compressed-tensors library and dequantized back to FP16/BF16 for inference. With this change, FP8 weights are kept in FP8 format and inference uses hardware-accelerated FP8 matmul kernels (torch._scaled_mm on XPU, fbgemm.f8f8bf16_rowwise on CUDA).

Key changes:

New file: src/transformers/integrations/compressed_tensors_fp8.py

  • CTFP8Linear: FP8 linear layer that stores weights in FP8 and uses row-wise FP8 matmul kernels. Activations are dynamically quantized per row via quantize_fp8_per_row (a sketch of this follows the list).
  • Weight converters (CompressedTensorsScaleConvert, CompressedTensorsFp8Dequantize) to handle the checkpoint format conversion (e.g. weight_scale → weight_scale_inv).
  • CTFP8PerRowQuantize: Online quantization support — quantize BF16 weights to FP8 per-row on-the-fly during model loading.
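
The core of the per-row scheme is how quantize_fp8_per_row maps each row onto the FP8 range. The snippet below is a minimal sketch of that idea, assuming the e4m3 format; the actual helper in this PR may differ in details such as clamping and scale dtype.

import torch

def quantize_fp8_per_row(x: torch.Tensor):
    # One scale per row: map the row's max magnitude onto the FP8 e4m3 range.
    finfo = torch.finfo(torch.float8_e4m3fn)
    row_max = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = row_max / finfo.max
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # Return FP8 values plus the per-row scale needed to undo the scaling.
    return x_fp8, scale.to(torch.float32)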

Modified: src/transformers/quantizers/quantizer_compressed_tensors.py

  • CompressedTensorsHfQuantizer now detects FP8 quantization configs (float type, num_bits=8) and automatically routes to the FP8 kernel path when a GPU/XPU is available; it falls back to the default compressed-tensors dequantize path on CPU (a sketch of this check follows the list).
  • Added get_weight_conversions() and get_quantize_ops() to support both pre-quantized loading and online quantization.
  • No changes to the non-FP8 code path — existing INT8/INT4 compressed-tensors models are unaffected.
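
As a rough illustration of the routing check (not the exact code in this PR), the decision boils down to inspecting the config groups and the available devices; the helper name below is made up for illustration:

import torch
from compressed_tensors.quantization import QuantizationType

def _should_use_fp8_kernels(quantization_config) -> bool:
    # All weight groups must be 8-bit float for the FP8 kernel path.
    for scheme in quantization_config.config_groups.values():
        weights = scheme.weights
        if weights is None or weights.num_bits != 8 or weights.type != QuantizationType.FLOAT:
            return False
    # Only take the kernel path on CUDA/XPU; otherwise use the default
    # compressed-tensors dequantize path on CPU.
    return torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())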

Modified: src/transformers/quantizers/auto.py

  • Minor formatting change (no functional change).

Supported models

Usage

Pre-quantized model (no config needed)

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
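
A quick way to check that the FP8 kernel path was taken is to inspect a linear layer's weight dtype after loading (the module path below assumes the Llama architecture and is only illustrative):

layer = model.model.layers[0].mlp.gate_proj
print(type(layer).__name__, layer.weight.dtype)
# On the FP8 kernel path the weight should stay in torch.float8_e4m3fn;
# on the CPU fallback it is dequantized to bfloat16.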

Online quantization

import torch
from transformers import AutoModelForCausalLM, CompressedTensorsConfig
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs, QuantizationType, QuantizationStrategy

ct_config = CompressedTensorsConfig(
    config_groups={
        "group_0": QuantizationScheme(
            weights=QuantizationArgs(
                num_bits=8, type=QuantizationType.FLOAT, strategy=QuantizationStrategy.CHANNEL,
            ),
        ),
    },
    run_compressed=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    quantization_config=ct_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
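
Once loaded, the model is used like any other transformers model; a standard generation call (shown only for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
inputs = tokenizer("FP8 inference test:", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))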

Devices

  • XPU (Intel Data Center Max / Arc): uses torch._scaled_mm
  • CUDA (SM89+): uses fbgemm.f8f8bf16_rowwise
  • CPU: falls back to default compressed-tensors dequantize path
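
For illustration, a rough sketch of how the forward pass could dispatch between these backends, assuming the quantize_fp8_per_row helper sketched earlier, a recent PyTorch where torch._scaled_mm accepts keyword row-wise scales, and fbgemm-gpu for the CUDA op; the PR's CTFP8Linear may differ in layout and argument details:

import torch

def fp8_forward(x_bf16, weight_fp8, weight_scale, bias=None):
    # Dynamically quantize activations per row, as CTFP8Linear does.
    x_fp8, x_scale = quantize_fp8_per_row(x_bf16)
    if x_fp8.device.type == "xpu":
        # XPU: torch._scaled_mm with per-row activation scales (M, 1) and
        # per-output-channel weight scales transposed to (1, N).
        return torch._scaled_mm(
            x_fp8, weight_fp8.t(),
            scale_a=x_scale, scale_b=weight_scale.t(),
            bias=bias, out_dtype=torch.bfloat16,
        )
    # CUDA (SM89+): row-wise FP8 GEMM from fbgemm-gpu; it expects 1-D scales.
    out = torch.ops.fbgemm.f8f8bf16_rowwise(
        x_fp8, weight_fp8, x_scale.squeeze(-1), weight_scale.squeeze(-1)
    )
    return out if bias is None else out + bias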

@sywangyi

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng jiqing-feng changed the title from Fp8 to Add FP8 kernel acceleration for compressed-tensors quantized models on Apr 29, 2026
@Rocketknight1 (Member) commented:

cc @SunMarc

@SunMarc (Member) left a comment:

Thanks, left a comment!

Comment on lines +54 to +55
- XPU: torch._scaled_mm
- CUDA: fbgemm.f8f8bf16_rowwise

Indeed, the model is dequantized on the fly even if run_compressed=True (this worked before, but compressed-tensors preferred to delegate this work to vLLM since it was duplicate work). We could add support for the most used methods here in transformers, but it would be nice to use kernels from vLLM if possible. Also, I don't know if it is worth using fbgemm: torchao doesn't use this lib anymore and it was quite a struggle to get it installed. The best would be to use kernels hosted on kernels-community, as we do for our current FP8 support.
