fix: use cfg.dtype instead of torch.get_default_dtype for KV cache init #1260
Merged
jlarson4 merged 3 commits into TransformerLensOrg:dev from Apr 20, 2026
Conversation
TransformerLensKeyValueCacheEntry.init_cache_entry initialised
past_keys and past_values with torch.get_default_dtype(), which is
torch.float32 unless the caller has explicitly overridden the global
default. When a model ran in float16 or bfloat16, the subsequent
torch.cat([past_keys, new_keys], dim=1) inside append() promoted the
result to float32. Downstream attention-score computation then failed
with:
RuntimeError: expected scalar type Half but found Float
at AbstractAttention.calculate_attention_scores (q_ @ k_ / attn_scale).
This blocked generate() with use_past_kv_cache=True (the default) for
any reduced-precision model. Disabling the KV cache worked but turned
generation into O(seq_len^2) per step, which is prohibitive for any
practical use.
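The promotion can be seen outside TransformerLens with a few lines of plain PyTorch; the shapes and names below are illustrative, not the library's internals:

```python
import torch

# Cache initialised with the global default dtype (float32 unless overridden)...
past_keys = torch.zeros(1, 0, 8, 64, dtype=torch.get_default_dtype())
# ...while an fp16 model produces fp16 keys for the new tokens.
new_keys = torch.randn(1, 4, 8, 64, dtype=torch.float16)

# torch.cat applies type promotion, so the concatenated keys come back as float32.
merged = torch.cat([past_keys, new_keys], dim=1)
print(merged.dtype)  # torch.float32

# Matmul does not promote across dtypes, so fp16 queries against the
# promoted fp32 keys raise the error quoted above.
q = torch.randn(1, 8, 4, 64, dtype=torch.float16)   # (batch, heads, q_pos, d_head)
k = merged.permute(0, 2, 1, 3)                       # (batch, heads, k_pos, d_head)
try:
    q @ k.transpose(-2, -1)
except RuntimeError as e:
    print(e)  # "expected scalar type Half but found Float" (exact wording varies by version)
```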
The fix uses cfg.dtype — the same dtype the rest of the model is loaded
with. This is what every production fp16 inference stack does
(HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM).
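A simplified sketch of the change, assuming a config object shaped like TransformerLens's (the real init_cache_entry also handles device placement and the grouped-query path; the signature below is illustrative, not the exact method):

```python
import torch

def init_cache_entry(cfg, device, batch_size=1):
    # Use the model's configured dtype rather than the process-wide default,
    # so later torch.cat calls in append() keep fp16/bf16 tensors as-is.
    n_kv_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads
    shape = (batch_size, 0, n_kv_heads, cfg.d_head)
    past_keys = torch.empty(shape, dtype=cfg.dtype, device=device)    # was torch.get_default_dtype()
    past_values = torch.empty(shape, dtype=cfg.dtype, device=device)  # was torch.get_default_dtype()
    return past_keys, past_values
```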
Added tests/unit/test_key_value_cache_entry.py covering:
- init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16
- behaviour is independent of torch.get_default_dtype()
- append() preserves cfg.dtype without promoting to fp32
- grouped-query-attention path uses n_key_value_heads correctly
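For reference, the first of these tests looks roughly like the following; the import path, config fields, and init_cache_entry signature are assumptions based on this description, not copied from the new file:

```python
import pytest
import torch

from transformer_lens import HookedTransformerConfig
from transformer_lens.past_key_value_caching import TransformerLensKeyValueCacheEntry


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_init_cache_entry_respects_cfg_dtype(dtype):
    cfg = HookedTransformerConfig(
        n_layers=1, d_model=16, n_ctx=8, d_head=4, n_heads=4,
        d_vocab=10, act_fn="relu", dtype=dtype,
    )
    entry = TransformerLensKeyValueCacheEntry.init_cache_entry(
        cfg, device=torch.device("cpu"), batch_size=2
    )
    # The cache must follow the model's dtype, not torch.get_default_dtype().
    assert entry.past_keys.dtype == dtype
    assert entry.past_values.dtype == dtype
```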
Collaborator
Thanks for fixing this! I had to clean up some CI inconsistencies from 3.0 to get this to pass; it will be in the next release.
Contributor
Author
Sure thing, Jonah! Thanks for the help getting it merged, and for all the rest you do for this project.
Description
TransformerLensKeyValueCacheEntry.init_cache_entry initialises past_keys and past_values with torch.get_default_dtype(), which is torch.float32 unless the caller has set a different global default. When a model is loaded in float16 or bfloat16, the subsequent torch.cat([past_keys, new_keys], dim=1) inside append() promotes the result to float32. The downstream attention-score computation then fails in AbstractAttention.calculate_attention_scores, at attn_scores = q_ @ k_ / self.attn_scale.
This blocks generate() with use_past_kv_cache=True (the default) for any reduced-precision model. Disabling the KV cache works but turns generation into O(seq_len²) per step, which is prohibitive in practice. I hit this running Llama 3.1 8B at fp16 on a 3090; the same reproducer triggers on any fp16/bf16 model with use_past_kv_cache=True.
Reproducer
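The reproducer snippet itself is not preserved in this extract; the sketch below is an illustrative stand-in based on the description above (model name and prompt are examples, not the exact code from the PR):

```python
import torch
from transformer_lens import HookedTransformer

# Any fp16/bf16 model triggers the failure; Llama 3.1 8B at fp16 was the
# setup described above.
model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-3.1-8B", dtype=torch.float16, device="cuda"
)

# use_past_kv_cache defaults to True; before this fix the call failed with
# "RuntimeError: expected scalar type Half but found Float".
print(model.generate("The quick brown fox", max_new_tokens=20, use_past_kv_cache=True))
```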
Fix
Use cfg.dtype, the same dtype the rest of the model is loaded with. This matches what many other production fp16 inference stacks do (HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM).
The KV cache tensors are initially of shape (batch, 0, n_heads, d_head), i.e. empty in the sequence dimension, so the dtype change only controls what format the subsequent torch.cat produces. No precision is lost or quantised that wouldn't already be lost by the model being in fp16.
Tests
Added tests/unit/test_key_value_cache_entry.py covering:
- init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16
- behaviour is independent of torch.get_default_dtype() (regression guard)
- append() preserves cfg.dtype without promoting to fp32
- grouped-query-attention path uses n_key_value_heads correctly
Type of change
Checklist