
fix: use cfg.dtype instead of torch.get_default_dtype for KV cache init #1260

Merged
jlarson4 merged 3 commits into TransformerLensOrg:dev from davidcyze:fix/kv-cache-dtype-matches-cfg on Apr 20, 2026

Conversation

@davidcyze
Contributor

Description

TransformerLensKeyValueCacheEntry.init_cache_entry initialises past_keys and past_values with torch.get_default_dtype(), which is torch.float32 unless the caller has set a different global default. When a model is loaded in float16 or bfloat16, the subsequent torch.cat([past_keys, new_keys], dim=1) inside append() promotes the result to float32. The downstream attention-score computation then fails at AbstractAttention.calculate_attention_scores:

RuntimeError: expected scalar type Half but found Float

at attn_scores = q_ @ k_ / self.attn_scale.
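
The promotion is ordinary torch.cat type promotion rather than anything cache-specific; a minimal standalone illustration (shapes chosen arbitrarily):

import torch

# Cache initialised with the global default dtype (fp32), new keys arriving in the model dtype (fp16)
past_keys = torch.empty(1, 0, 12, 64, dtype=torch.float32)
new_keys = torch.randn(1, 8, 12, 64, dtype=torch.float16)

merged = torch.cat([past_keys, new_keys], dim=1)
print(merged.dtype)  # torch.float32: the fp16 keys are silently promoted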

This blocks generate() with use_past_kv_cache=True (the default) for any reduced-precision model. Disabling the KV cache works but turns generation into O(seq_len²) per step, which is prohibitive in practice. I hit this running Llama 3.1 8B at fp16 on a 3090; the same reproducer triggers on any fp16/bf16 model with use_past_kv_cache=True.

Reproducer

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2", dtype=torch.float16, device="cuda")
tokens = model.to_tokens("Hello world")
model.generate(tokens, max_new_tokens=16, use_past_kv_cache=True)
# RuntimeError: expected scalar type Half but found Float

Fix

Use cfg.dtype — the same dtype the rest of the model is loaded with. This matches what many other production fp16 inference stacks do (HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM).

The KV cache tensors are initially shape (batch, 0, n_heads, d_head) — empty in the sequence dimension — so the dtype change only controls what format the subsequent torch.cat produces. No precision is lost or quantised that wouldn't already be lost by the model being in fp16.
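
Concretely, the change amounts to roughly the following (a sketch rather than the exact diff; batch_size, device and n_kv_heads stand in for whatever init_cache_entry already has in scope):

# Inside TransformerLensKeyValueCacheEntry.init_cache_entry (sketch):
past_keys = torch.empty(
    (batch_size, 0, n_kv_heads, cfg.d_head),
    device=device,
    dtype=cfg.dtype,  # was: torch.get_default_dtype()
)
past_values = torch.empty(
    (batch_size, 0, n_kv_heads, cfg.d_head),
    device=device,
    dtype=cfg.dtype,  # was: torch.get_default_dtype()
)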

Tests

Added tests/unit/test_key_value_cache_entry.py covering the following (a sketch of the regression-guard case follows the list):

  • init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16
  • behaviour is independent of torch.get_default_dtype() (regression guard)
  • append() preserves cfg.dtype without promoting to fp32
  • GQA path uses n_key_value_heads correctly
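
The regression guard looks roughly like this (a sketch under assumptions: the import path and the init_cache_entry signature are taken from the existing past_key_value_caching module and may differ slightly from the merged test):

import torch
from transformer_lens import HookedTransformerConfig
# Import path assumed for illustration; point it at wherever the entry class lives.
from transformer_lens.past_key_value_caching import TransformerLensKeyValueCacheEntry


def test_init_cache_entry_ignores_global_default_dtype():
    # cfg requests fp16 while the process-wide default stays fp32.
    cfg = HookedTransformerConfig(
        n_layers=1, d_model=8, n_ctx=16, d_head=4, n_heads=2,
        d_vocab=10, act_fn="relu", dtype=torch.float16,
    )
    prev = torch.get_default_dtype()
    torch.set_default_dtype(torch.float32)
    try:
        entry = TransformerLensKeyValueCacheEntry.init_cache_entry(
            cfg, device=torch.device("cpu"), batch_size=1
        )
        # The cache must follow cfg.dtype, not the global default.
        assert entry.past_keys.dtype == torch.float16
        assert entry.past_values.dtype == torch.float16
    finally:
        torch.set_default_dtype(prev)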

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist

  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

davidcyze and others added 3 commits April 19, 2026 22:50
TransformerLensKeyValueCacheEntry.init_cache_entry initialised
past_keys and past_values with torch.get_default_dtype(), which is
torch.float32 unless the caller has explicitly overridden the global
default. When a model runs in float16 or bfloat16, the subsequent
torch.cat([past_keys, new_keys], dim=1) inside append() promoted the
result to float32. Downstream attention-score computation then failed
with:

    RuntimeError: expected scalar type Half but found Float

at AbstractAttention.calculate_attention_scores (q_ @ k_ / attn_scale).

This blocked generate() with use_past_kv_cache=True (the default) for
any reduced-precision model. Disabling the KV cache worked but turned
generation into O(seq_len^2) per step, which is prohibitive for any
practical use.

The fix uses cfg.dtype — the same dtype the rest of the model is loaded
with. This matches what production fp16 inference stacks generally do
(HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM).

Added tests/unit/test_key_value_cache_entry.py covering:
  - init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16
  - behaviour is independent of torch.get_default_dtype()
  - append() preserves cfg.dtype without promoting to fp32
  - grouped-query-attention path uses n_key_value_heads correctly
@jlarson4
Collaborator

Thanks for fixing this! I had to clean up some CI inconsistencies from 3.0 to get this to pass; it will be in the next release.

jlarson4 merged commit 9ef4e4c into TransformerLensOrg:dev on Apr 20, 2026
43 of 44 checks passed
@davidcyze
Contributor Author

Sure thing, Jonah! Thanks for the help getting it merged, and for all the rest you do for this project.
