
TransformerLens 3.1.0 #1277

Merged

jlarson4 merged 21 commits into main from dev on Apr 30, 2026

Conversation

@jlarson4 (Collaborator)

Description

  • Resolved as many open contribution PRs for the legacy HookedTransformer system as possible
  • Resolved as many open issues for the legacy HookedTransformer system as possible
  • Added a Baichuan adapter
  • Added an Architecture Adapter Creation Guide to the documentation to help contributors add new architectures to TransformerBridge
  • Added a fix that allows quantized models to run properly on TransformerBridge
  • Added an experimental device_map feature; it still needs testing on actual multi-GPU devices
  • Added generate_stream to HookedTransformer and TransformerBridge (usage sketch below)
  • Fixed a series of legacy bugs in HookedTransformer (see commit history)
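
A hypothetical usage sketch for the new streaming API; the assumption is that generate_stream() accepts generate()-style arguments and yields partial output as it is produced (check the commits for the exact signature):

```python
# Assumed call shape only: generate_stream()'s real signature is defined
# in this PR's commits; from_pretrained() and the prompt/max_new_tokens
# arguments mirror the existing generate() API.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
for chunk in model.generate_stream("The quick brown fox", max_new_tokens=20):
    print(chunk, end="", flush=True)
```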

Type of change


  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

jlarson4 and others added 21 commits April 20, 2026 13:22
* Fixing `utils` imports

* skip gated notebooks on PR from forks

* Updating notebooks

* Ensure LLaMA only runs when HF_TOKEN is available
…it (#1260)

* fix: use cfg.dtype instead of torch.get_default_dtype for KV cache init

TransformerLensKeyValueCacheEntry.init_cache_entry initialised
past_keys and past_values with torch.get_default_dtype(), which is
torch.float32 unless the caller has explicitly overridden the global
default. When a model runs in float16 or bfloat16, the subsequent
torch.cat([past_keys, new_keys], dim=1) inside append() promoted the
result to float32. Downstream attention-score computation then failed
with:

    RuntimeError: expected scalar type Half but found Float

at AbstractAttention.calculate_attention_scores (q_ @ k_ / attn_scale).

This blocked generate() with use_past_kv_cache=True (the default) for
any reduced-precision model. Disabling the KV cache worked but turned
generation into O(seq_len^2) per step, which is prohibitive for any
practical use.

The fix uses cfg.dtype — the same dtype the rest of the model is loaded
with. This is what every production fp16 inference stack does
(HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM).
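
A minimal sketch of the fix, with the config stubbed out; the real code lives on TransformerLensKeyValueCacheEntry and tracks more state:

```python
import torch
from dataclasses import dataclass

@dataclass
class Cfg:  # stand-in for the model config consulted by the cache entry
    dtype: torch.dtype = torch.float16
    n_heads: int = 8
    d_head: int = 64

def init_cache_entry(cfg: Cfg, batch_size: int = 1, device: str = "cpu"):
    shape = (batch_size, 0, cfg.n_heads, cfg.d_head)
    # Before: torch.zeros(shape, device=device) used the global default
    # dtype (float32), so append()'s torch.cat promoted fp16/bf16 caches.
    past_keys = torch.zeros(shape, dtype=cfg.dtype, device=device)
    past_values = torch.zeros(shape, dtype=cfg.dtype, device=device)
    return past_keys, past_values

keys, _ = init_cache_entry(Cfg(dtype=torch.bfloat16))
assert keys.dtype == torch.bfloat16  # independent of torch.get_default_dtype()
```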

Added tests/unit/test_key_value_cache_entry.py covering:
  - init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16
  - behaviour is independent of torch.get_default_dtype()
  - append() preserves cfg.dtype without promoting to fp32
  - grouped-query-attention path uses n_key_value_heads correctly

* isort imports

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* Fix apertus test failing on machines with GPU

Tensor equality includes the device, so set device="cpu" so that
weight tensors always match the expected values, even if there's a GPU
they could have been created on.
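
A small illustration of the failure mode (toy tensors, not the project's test code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
actual = torch.ones(2, 2, device=device)  # lands on GPU when one exists
expected = torch.ones(2, 2)               # always CPU
# On a GPU machine the two tensors live on different devices, so a naive
# equality check fails; pinning creation to device="cpu" (the fix) makes
# the comparison machine-independent.
assert torch.equal(actual.to("cpu"), expected)
```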

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Fix test_cuda using nonexistent mlm_tokens fixture

The test_cuda function referenced a fixture named mlm_tokens which was
never defined, causing a fixture-not-found error. Changed to use the
existing tokens fixture which provides the same MLM-style tokenized input.
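
The shape of the fix in pytest terms; the fixture body here is hypothetical, only the rename is from the commit:

```python
import pytest
import torch

@pytest.fixture
def tokens():
    # Hypothetical MLM-style tokenised input; the real fixture already
    # exists in the test module.
    return torch.tensor([[101, 2054, 2003, 103, 102]])

def test_cuda(tokens):  # was test_cuda(mlm_tokens): fixture not found
    assert tokens.ndim == 2
```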

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Resolve conflict

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonah Larson <jonahalarson@comcast.net>
… (#1215)

* fix: handle LayerNorm folding correctly in load_and_process_state_dict

Previously, calling load_and_process_state_dict(state_dict, fold_ln=True)
had two failure modes:

1. If the state_dict had unfolded LN weights, fold_layer_norm removed
   the LN keys but the model's modules were not replaced with LNPre,
   leaving mismatched architecture and broken hooks.

2. If the state_dict was already folded (no LN keys), fold_layer_norm
   crashed with a KeyError trying to access missing LN weight keys.

Fix both by:
- Checking whether LN keys exist before attempting to fold (skip with
  warning if already folded)
- Replacing LN/RMS modules with LNPre/RMSPre before folding, matching
  the logic previously only in process_weights_
- Calling self.setup() after loading to re-attach hooks
- Simplifying process_weights_ to delegate fully to the fixed method

Fixes #219
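
A minimal, self-contained sketch of the guard; the LNPre/RMSPre module swap, the setup() re-hook, and the actual folding arithmetic are all elided:

```python
import logging

def fold_ln_if_present(state_dict: dict, fold_ln: bool = True) -> dict:
    # TransformerLens LN parameters use keys like "blocks.0.ln1.w" and
    # "ln_final.w"; this pattern match is a simplification.
    ln_keys = [k for k in state_dict if ".ln" in k or k.startswith("ln_final")]
    if fold_ln and not ln_keys:
        # Previously this path crashed with KeyError; now it warns and skips.
        logging.warning("State dict appears already folded; skipping fold_ln.")
        return state_dict
    if fold_ln:
        for k in ln_keys:
            state_dict.pop(k)  # real code folds these into W_Q/W_K/W_V/W_in first
    return state_dict
```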

Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>

* style: fix black formatting in HookedTransformer.py

* fix: handle LayerNorm folding correctly in load_and_process_state_dict

* Make sure not to double fold

---------

Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
rotary_base is frequently set to floats in the code but was typed as
an int, causing beartype errors if the configs get loaded in a test:
https://github.com/TransformerLensOrg/TransformerLens/blob/9c5a2a81674d5bcefa641c816b66e9827ccdf637/transformer_lens/loading_from_pretrained.py#L1984

HF configs allegedly always have rope_theta as a float:
https://github.com/huggingface/transformers/blob/c38b2fb78eaedd4261a0e446f7976345cd1c7f1b/src/transformers/modeling_rope_utils.py#L645

But sometimes they're actually ints, and beartype doesn't consider int
to be a subtype of float:
beartype/beartype#66

This updates the type to Union[float, int] to be accurate while keeping
beartype happy.
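
Why the Union matters under beartype, shown with a toy function rather than the real config field:

```python
from typing import Union
from beartype import beartype

@beartype
def set_rotary_base(rotary_base: Union[float, int]) -> float:
    return float(rotary_base)

set_rotary_base(10000)     # int from an HF config: accepted with the Union,
                           # rejected by beartype under a plain `float` annotation
set_rotary_base(500000.0)  # float: accepted either way
```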

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* fixed batching in generate

* added test case

* Move & improve tests

* make check format and mypy

* fix mypy errors

* Stop jaxtyping failures

* Updated to also fix TransformerBridge for the same issue

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* adds HookedTransformer.generate_stream()

* fixes mypy errors

* Adjusted for TransformerLens 3 changes

---------

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* Add attention_mask argument to loss_fn() and lm_cross_entropy_loss() and adjust the cross-entropy calculation to ignore masked (padding) tokens (see the sketch below).

* updated lock file

* locked numpy below 2
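
A simplified sketch of the masked loss; the real lm_cross_entropy_loss also supports a per-token return, but the masking rule is the point here:

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(logits: torch.Tensor, tokens: torch.Tensor,
                   attention_mask: torch.Tensor) -> torch.Tensor:
    # Standard next-token objective: position t predicts token t+1.
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
    target = tokens[:, 1:]
    nll = -log_probs.gather(-1, target[..., None]).squeeze(-1)
    # The change: padding positions no longer contribute to the average.
    mask = attention_mask[:, 1:].to(nll.dtype)
    return (nll * mask).sum() / mask.sum()
```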

---------

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* Make `FactoredMatrix` compatible with tensor-like arguments

I'd like to be able to use `FactoredMatrix` with things that implement the interface of `torch.Tensor` without subclassing it.  This slight change allows `FactoredMatrix` to work with such classes rather than returning `None` in various places.

* Added a test and properly typed the methods to include details of this fix
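
Basic FactoredMatrix usage for context; per this commit, the factors may now be tensor-likes that implement the torch.Tensor interface rather than subclassing it:

```python
import torch
from transformer_lens import FactoredMatrix

A = torch.randn(12, 4)
B = torch.randn(4, 12)
fm = FactoredMatrix(A, B)  # low-rank product kept in factored form
print(fm.AB.shape)         # torch.Size([12, 12]), materialised on access
```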

---------

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
… model (#629)

* Update convert_nanogpt_weights to have attention mask and handle case when there is no bias.

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* ran format

* Make beartyping dependency more forgiving

Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* generated lock file

* Added an error if cfg.d_mlp is None
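
A hypothetical form of the new guard; the error message is mine, the check is what the commit adds:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Cfg:  # stand-in for the conversion config
    d_mlp: Optional[int] = None

def require_d_mlp(cfg: Cfg) -> int:
    if cfg.d_mlp is None:
        raise ValueError("convert_nanogpt_weights requires cfg.d_mlp to be set")
    return cfg.d_mlp
```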

---------

Signed-off-by: Dashiell Stander <dstander@protonmail.com>
Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* Added n_ctx override to TransformerBridge

* Prevent output of progress bars in T5 demo
* adds HookedTransformer.generate_stream()

* fixes mypy errors

* Adjusted for TransformerLens 3 changes

* Initial bridge generate stream

* TransformerBridge Generate Stream

---------

Co-authored-by: anthonyduong <anthonyduong9@gmail.com>
Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
* Improved an issue with tokenize and concatenate

* dialed in the approach to be per-doc
* Multi-GPU initial setup for TransformerBridge
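
A hypothetical call shape for the experimental multi-GPU path; both the boot_transformers entry point and the device_map kwarg are assumptions to verify against the diff:

```python
# Assumption: TransformerBridge exposes a Hugging Face-style device_map
# option; the PR notes this still needs testing on real multi-GPU machines.
from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2", device_map="auto")
```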

* Added additional documentation note
* Adding architecture Adapter creation guide, add split QKV example to quantized LLaMA demo

* ignore docs/build from black linting
* Fixed Quantization bug in TransformerLens 3.0

* Format fixes
jlarson4 merged commit 6f56518 into main on Apr 30, 2026
75 of 88 checks passed