From 7c95d7c624daa452a47fd7eb022b168f9fe250db Mon Sep 17 00:00:00 2001 From: David Cyze Date: Sun, 19 Apr 2026 22:30:28 -0500 Subject: [PATCH 1/2] fix: use cfg.dtype instead of torch.get_default_dtype for KV cache init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TransformerLensKeyValueCacheEntry.init_cache_entry initialised past_keys and past_values with torch.get_default_dtype(), which is torch.float32 unless the caller has explicitly overridden the global default. When a model runs in float16 or bfloat16, the subsequent torch.cat([past_keys, new_keys], dim=1) inside append() promoted the result to float32. Downstream attention-score computation then failed with: RuntimeError: expected scalar type Half but found Float at AbstractAttention.calculate_attention_scores (q_ @ k_ / attn_scale). This blocked generate() with use_past_kv_cache=True (the default) for any reduced-precision model. Disabling the KV cache worked, but every generation step then recomputed attention over the whole sequence (O(seq_len^2) per step instead of O(seq_len) with the cache), which is prohibitive for any practical use. The fix uses cfg.dtype — the same dtype the rest of the model is loaded with. This is the standard approach in production fp16 inference stacks (e.g. HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM). 
Added tests/unit/test_key_value_cache_entry.py covering: - init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16 - behaviour is independent of torch.get_default_dtype() - append() preserves cfg.dtype without promoting to fp32 - grouped-query-attention path uses n_key_value_heads correctly --- tests/unit/test_key_value_cache_entry.py | 104 ++++++++++++++++++ .../cache/key_value_cache_entry.py | 10 +- 2 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_key_value_cache_entry.py diff --git a/tests/unit/test_key_value_cache_entry.py b/tests/unit/test_key_value_cache_entry.py new file mode 100644 index 000000000..a91f32482 --- /dev/null +++ b/tests/unit/test_key_value_cache_entry.py @@ -0,0 +1,104 @@ +"""Tests for TransformerLensKeyValueCacheEntry.init_cache_entry dtype behaviour. + +The buggy pre-fix code used ``torch.get_default_dtype()`` to initialise +``past_keys`` and ``past_values``. PyTorch's default is ``torch.float32``, +so the bug silently produced the correct dtype for fp32 models but the +wrong dtype (fp32 instead of fp16/bf16) for reduced-precision ones. Of +the tests below, ``test_init_cache_entry_uses_cfg_dtype_float32`` is +therefore a baseline sanity check that passes against both the buggy +and fixed code — it verifies the common case works, not that the bug is +absent. The real regression guards are +``test_init_cache_entry_uses_cfg_dtype_float16``, +``..._bfloat16``, ``..._dtype_independent_of_global_default``, and +``test_append_preserves_cfg_dtype``, which all fail against the buggy +code (the fp16 cache was getting promoted to fp32 by the bug, breaking +the downstream attention-score matmul). 
+""" + +import torch + +from transformer_lens.cache.key_value_cache_entry import TransformerLensKeyValueCacheEntry +from transformer_lens.config.TransformerLensConfig import TransformerLensConfig + + +def _make_cfg(dtype: torch.dtype, n_heads: int = 4, d_head: int = 8, n_key_value_heads=None): + return TransformerLensConfig( + d_model=n_heads * d_head, + d_head=d_head, + n_layers=1, + n_ctx=32, + n_heads=n_heads, + n_key_value_heads=n_key_value_heads, + dtype=dtype, + ) + + +def test_init_cache_entry_uses_cfg_dtype_float32(): + """Baseline: cfg.dtype=float32 produces fp32 buffers. + + Note: this test passes against both the buggy and fixed implementations + because torch's default dtype is also float32. It is a sanity check + that the common case works, not a regression guard for the specific + bug this module was added to prevent. See module docstring and + ``test_init_cache_entry_dtype_independent_of_global_default`` for the + tests that discriminate fix vs bug. + """ + cfg = _make_cfg(dtype=torch.float32) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.float32 + assert entry.past_values.dtype == torch.float32 + + +def test_init_cache_entry_uses_cfg_dtype_float16(): + cfg = _make_cfg(dtype=torch.float16) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.float16 + assert entry.past_values.dtype == torch.float16 + + +def test_init_cache_entry_uses_cfg_dtype_bfloat16(): + cfg = _make_cfg(dtype=torch.bfloat16) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.bfloat16 + assert entry.past_values.dtype == torch.bfloat16 + + +def test_init_cache_entry_dtype_independent_of_global_default(): + """Regression guard: cache dtype follows cfg.dtype, not the global default. 
+ + Also covers the fp32 case indirectly: if someone reintroduces the old + ``torch.get_default_dtype()`` behaviour, this test plus the fp16 / + bfloat16 / append / GQA tests catch it; the fp32-only baseline above + would not, since fp32 happens to be torch's global default. + """ + cfg = _make_cfg(dtype=torch.float16) + original_default = torch.get_default_dtype() + try: + torch.set_default_dtype(torch.float32) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + assert entry.past_keys.dtype == torch.float16 + assert entry.past_values.dtype == torch.float16 + finally: + torch.set_default_dtype(original_default) + + +def test_append_preserves_cfg_dtype(): + """After append, past_keys stays in cfg.dtype — no float promotion.""" + cfg = _make_cfg(dtype=torch.float16) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu") + new_keys = torch.randn(1, 3, cfg.n_heads, cfg.d_head, dtype=torch.float16) + new_values = torch.randn(1, 3, cfg.n_heads, cfg.d_head, dtype=torch.float16) + updated_keys, updated_values = entry.append(new_keys, new_values) + assert updated_keys.dtype == torch.float16 + assert updated_values.dtype == torch.float16 + assert entry.past_keys.dtype == torch.float16 + assert entry.past_values.dtype == torch.float16 + + +def test_init_cache_entry_handles_grouped_query_attention(): + """When n_key_value_heads is set (GQA), it should be used instead of n_heads.""" + cfg = _make_cfg(dtype=torch.float16, n_heads=32, d_head=128, n_key_value_heads=8) + entry = TransformerLensKeyValueCacheEntry.init_cache_entry(cfg, device="cpu", batch_size=2) + assert entry.past_keys.shape == (2, 0, 8, 128) + assert entry.past_values.shape == (2, 0, 8, 128) + assert entry.past_keys.dtype == torch.float16 diff --git a/transformer_lens/cache/key_value_cache_entry.py b/transformer_lens/cache/key_value_cache_entry.py index eec478f7e..b8c8c57a8 100644 --- a/transformer_lens/cache/key_value_cache_entry.py +++ 
b/transformer_lens/cache/key_value_cache_entry.py @@ -27,12 +27,18 @@ def init_cache_entry( batch_size: int = 1, ): n_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads + # Use cfg.dtype so the cache matches the model's dtype. Using + # torch.get_default_dtype() (which is float32 unless the caller has + # set it) caused the subsequent torch.cat([past_keys, new_keys]) to + # promote the result to float32 when the model runs in float16 or + # bfloat16, which in turn broke the attention-score matmul with + # "expected scalar type Half but found Float". return cls( past_keys=torch.empty( - (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=torch.get_default_dtype() + (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype ), past_values=torch.empty( - (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=torch.get_default_dtype() + (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype ), ) From 4679a71c828163bea835c53c9285eec8b0b2ca1b Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 20 Apr 2026 13:28:53 -0500 Subject: [PATCH 2/2] isort imports --- tests/unit/test_key_value_cache_entry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_key_value_cache_entry.py b/tests/unit/test_key_value_cache_entry.py index a91f32482..b7806454a 100644 --- a/tests/unit/test_key_value_cache_entry.py +++ b/tests/unit/test_key_value_cache_entry.py @@ -17,7 +17,9 @@ import torch -from transformer_lens.cache.key_value_cache_entry import TransformerLensKeyValueCacheEntry +from transformer_lens.cache.key_value_cache_entry import ( + TransformerLensKeyValueCacheEntry, +) from transformer_lens.config.TransformerLensConfig import TransformerLensConfig