Questions about supporting KV cache quantization for models that do not currently support quantized cache #33231

@huangyuxiang03

Description

System Info

  • transformers version: 4.44.2
  • Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.17
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.3
  • Safetensors version: 0.4.3
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA H800

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers.cache_utils import QuantoQuantizedCache, QuantizedCacheConfig

BS = 1024  # prefill chunk size

@torch.no_grad()
def gen(model, input_ids, max_new_tokens, eos_token_id):
    past_key_values = QuantoQuantizedCache(QuantizedCacheConfig(nbits=2, compute_dtype=torch.bfloat16))

    # Prefill the prompt in chunks of BS tokens, accumulating the quantized cache.
    for b in range(0, input_ids.shape[-1], BS):
        e = min(input_ids.shape[-1], b + BS)
        output = model(input_ids[:, b:e], past_key_values=past_key_values)
        past_key_values = output.past_key_values

    # Greedy decoding, one token at a time (batch size 1 assumed by .item()).
    generated_tokens = []
    input_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated_tokens.append(input_id.item())

    for _ in range(max_new_tokens - 1):
        output = model(input_id, past_key_values=past_key_values)
        past_key_values = output.past_key_values
        input_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if input_id.item() == eos_token_id:
            break
        generated_tokens.append(input_id.item())

    generated_tokens = torch.tensor(generated_tokens, device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)
    input_ids = torch.cat((input_ids, generated_tokens), dim=-1)
    return input_ids
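
For context, a minimal driver for the snippet above might look like the following. The checkpoint name, prompt, and generation length are illustrative assumptions, not part of the original report; the quanto package must be installed for QuantoQuantizedCache to work.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint for illustration; any supported causal LM would do.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

input_ids = tokenizer("Tell me a story.", return_tensors="pt").input_ids.to(model.device)
output_ids = gen(model, input_ids, max_new_tokens=128, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))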

Expected behavior

The code snippet provided above generates random output for Phi-3-mini-128K, a model that does not natively support KV cache quantization.

However, from my understanding of the quantized cache supported in Hugging Face Transformers, one can simply replace an instance of DynamicCache with QuantoQuantizedCache to enable KV cache quantization. This is also mentioned in #30483 (comment). Phi-3-mini-128K is a fairly standard decoder-only Transformer model with only a few architectural modifications compared with Llama, so I believe that if the quantized cache works correctly on Llama (which it does), it should also work correctly on Phi-3. Indeed, the code snippet generates high-quality output on Llama, but it generates random tokens on Phi-3.
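
For reference, the higher-level path through generate that I believe is equivalent to the manual cache swap above is roughly the following; the exact cache_config keys are my reading of the docs around #30483, not something from the original report.

out = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=False,
    cache_implementation="quantized",  # asks generate to build a quantized cache internally
    cache_config={"backend": "quanto", "nbits": 2},  # mirrors the nbits=2 setting above
)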

Additionally, could you provide a README that teaches model contributors how to enable KV cache quantization?
