Conversation
* Fixing `utils` imports
* Skip gated notebooks on PRs from forks
* Updating notebooks
* Ensure LLaMA only runs when HF_TOKEN is available
…it (#1260)

* fix: use cfg.dtype instead of torch.get_default_dtype for KV cache init

  TransformerLensKeyValueCacheEntry.init_cache_entry initialised past_keys and past_values with torch.get_default_dtype(), which is torch.float32 unless the caller has explicitly overridden the global default. When a model runs in float16 or bfloat16, the subsequent torch.cat([past_keys, new_keys], dim=1) inside append() promoted the result to float32. Downstream attention-score computation then failed with `RuntimeError: expected scalar type Half but found Float` at AbstractAttention.calculate_attention_scores (q_ @ k_ / attn_scale).

  This blocked generate() with use_past_kv_cache=True (the default) for any reduced-precision model. Disabling the KV cache worked but turned generation into O(seq_len^2) per step, which is prohibitive for any practical use.

  The fix uses cfg.dtype, the same dtype the rest of the model is loaded with. This is what every production fp16 inference stack does (HuggingFace transformers, vLLM, TGI, llama.cpp, TensorRT-LLM).

  Added tests/unit/test_key_value_cache_entry.py covering:
  - init_cache_entry respects cfg.dtype for fp32, fp16, bfloat16
  - behaviour is independent of torch.get_default_dtype()
  - append() preserves cfg.dtype without promoting to fp32
  - grouped-query-attention path uses n_key_value_heads correctly

* isort imports

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
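The core of the fix is small: allocate the empty cache tensors with the model's configured dtype rather than the process-wide default. The sketch below is a minimal standalone illustration of that idea, not the actual TransformerLensKeyValueCacheEntry code; the `init_cache_entry` signature shown here is hypothetical.

```python
import torch

# Minimal sketch: allocate the empty KV cache with the model's dtype so that
# later torch.cat calls never promote fp16/bf16 activations back to fp32.
def init_cache_entry(batch_size: int, n_heads: int, d_head: int,
                     dtype: torch.dtype, device: torch.device):
    # Zero length along the sequence dimension; it grows as tokens are appended.
    past_keys = torch.empty(batch_size, 0, n_heads, d_head, dtype=dtype, device=device)
    past_values = torch.empty(batch_size, 0, n_heads, d_head, dtype=dtype, device=device)
    return past_keys, past_values

past_k, past_v = init_cache_entry(1, 12, 64, dtype=torch.float16, device=torch.device("cpu"))
new_k = torch.randn(1, 1, 12, 64, dtype=torch.float16)
# Concatenation stays float16 because both operands already share the model dtype.
assert torch.cat([past_k, new_k], dim=1).dtype == torch.float16
```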
* Fix apertus test failing on machines with a GPU

  Tensor equality includes the device, so set device="cpu" when constructing the expected weight tensors. That way they always match, even on machines with a GPU where the tensors could otherwise be created on the GPU.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Fix test_cuda using nonexistent mlm_tokens fixture

  The test_cuda function referenced a fixture named mlm_tokens which was never defined, causing a fixture-not-found error. Changed it to use the existing tokens fixture, which provides the same MLM-style tokenized input.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Resolve conflict

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonah Larson <jonahalarson@comcast.net>
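A minimal sketch of the test pattern being described, assuming torch.testing.assert_close as the comparison (the actual apertus test may use a different helper): pinning both sides to the CPU keeps the device part of the comparison deterministic across CPU-only and GPU machines.

```python
import torch

# Illustrative only: build the expected tensor explicitly on the CPU so the
# comparison does not depend on whether a GPU happens to be available.
expected = torch.zeros(4, 4, device="cpu")
actual = torch.zeros(4, 4, device="cpu")  # in the real test this comes from the converted weights

# assert_close checks device as well as values by default (check_device=True),
# so both tensors must live on the same device for the test to pass everywhere.
torch.testing.assert_close(actual, expected)
```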
#1215)

* fix: handle LayerNorm folding correctly in load_and_process_state_dict

  Previously, calling load_and_process_state_dict(state_dict, fold_ln=True) had two failure modes:

  1. If the state_dict had unfolded LN weights, fold_layer_norm removed the LN keys but the model's modules were not replaced with LNPre, leaving a mismatched architecture and broken hooks.
  2. If the state_dict was already folded (no LN keys), fold_layer_norm crashed with a KeyError trying to access the missing LN weight keys.

  Fix both by:
  - Checking whether LN keys exist before attempting to fold (skip with a warning if already folded)
  - Replacing LN/RMS modules with LNPre/RMSPre before folding, matching the logic previously only in process_weights_
  - Calling self.setup() after loading to re-attach hooks
  - Simplifying process_weights_ to delegate fully to the fixed method

  Fixes #219

  Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>

* style: fix black formatting in HookedTransformer.py
* fix: handle LayerNorm folding correctly in load_and_process_state_dict
* Make sure not to double fold

---------

Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
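A minimal sketch of the "skip if already folded" guard described above, using hypothetical helper names and a GPT-2-style key layout (`blocks.{i}.ln1.w`); the real load_and_process_state_dict additionally swaps the LN modules for their Pre variants and calls setup().

```python
import logging

def has_unfolded_ln(state_dict: dict, n_layers: int) -> bool:
    # LayerNorm weights are only present when the state dict has not been folded yet.
    return any(f"blocks.{i}.ln1.w" in state_dict for i in range(n_layers))

def should_fold(state_dict: dict, n_layers: int, fold_ln: bool) -> bool:
    """Decide whether LayerNorm folding should actually run for this state dict."""
    if fold_ln and not has_unfolded_ln(state_dict, n_layers):
        logging.warning("No LayerNorm weights found; state dict appears already folded, skipping fold_ln.")
        return False
    return fold_ln

# An already-folded dict no longer triggers a KeyError: folding is simply skipped.
already_folded = {"blocks.0.attn.W_Q": None, "blocks.0.mlp.W_in": None}
assert should_fold(already_folded, n_layers=1, fold_ln=True) is False
```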
rotary_base is frequently set to floats in the code but was typed as an int, causing beartype errors if the configs get loaded in a test: https://github.com/TransformerLensOrg/TransformerLens/blob/9c5a2a81674d5bcefa641c816b66e9827ccdf637/transformer_lens/loading_from_pretrained.py#L1984

HF configs allegedly always have rope_theta as a float: https://github.com/huggingface/transformers/blob/c38b2fb78eaedd4261a0e446f7976345cd1c7f1b/src/transformers/modeling_rope_utils.py#L645

But sometimes they are actually ints, and beartype doesn't consider int to be a subtype of float: beartype/beartype#66

This updates the type to Union[float, int] to be accurate while keeping beartype happy.

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
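For illustration, a minimal sketch of the widened annotation; the dataclass below is a hypothetical stand-in for the real config class, not its actual definition.

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class RotaryConfig:  # hypothetical stand-in; the real field lives on the model config
    # beartype does not treat int as a subtype of float, so a bare `float`
    # annotation rejects integer values such as rope_theta=10000 from HF configs.
    rotary_base: Union[float, int] = 10000.0

RotaryConfig(rotary_base=500000.0)  # float, the common HF case
RotaryConfig(rotary_base=10000)     # int, now also accepted by the annotation
```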
* fixed batching in generate
* added test case
* Move & improve tests
* make check format and mypy
* fix mypy errors
* Stop jaxtyping failures
* Updated to also fix TransformerBridge for the same issue

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* adds HookedTransformer.generate_stream()
* fixes mypy errors
* Adjusted for TransformerLens 3 changes

---------

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* Add attention_mask argument to loss_fn() and lm_cross_entropy_loss() and adjust the cross entropy calculation to ignore masked (padding) tokens.
* updated lock file
* locked numpy below 2

---------

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
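A minimal sketch of the masking idea, using a hypothetical helper rather than the library's exact loss_fn signature: only positions whose attention mask is 1 contribute to the mean loss, so padding tokens in a right-padded batch no longer drag it down.

```python
import torch
import torch.nn.functional as F

def masked_lm_cross_entropy(logits, tokens, attention_mask):
    # Standard next-token loss: position t predicts tokens[t + 1].
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
    target = tokens[:, 1:]
    token_log_probs = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    # Only count positions whose target token is real (mask == 1), not padding.
    mask = attention_mask[:, 1:].to(token_log_probs.dtype)
    return -(token_log_probs * mask).sum() / mask.sum()

logits = torch.randn(2, 5, 50257)
tokens = torch.randint(0, 50257, (2, 5))
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])  # second sequence is padded
loss = masked_lm_cross_entropy(logits, tokens, attention_mask)
```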
* Make `FactoredMatrix` compatible with tensor-like arguments

  I'd like to be able to use `FactoredMatrix` with things that implement the interface of `torch.Tensor` without subclassing it. This slight change allows `FactoredMatrix` to work with such classes rather than returning `None` in various places.

* Added test and properly typed the methods to include details of this fix

---------

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
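An illustrative sketch of the general pattern only (TinyFactored is a made-up stand-in, not the FactoredMatrix implementation): accept anything that exposes the relevant parts of the Tensor interface instead of requiring an exact isinstance check, and return NotImplemented rather than silently producing None.

```python
import torch

class TinyFactored:  # hypothetical stand-in for a factored-matrix wrapper
    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        self.A, self.B = A, B  # represents the product A @ B without materialising it

    def __matmul__(self, other):
        # Duck-typed check: anything with ndim and @ support is accepted, not just torch.Tensor.
        if hasattr(other, "ndim") and hasattr(other, "__matmul__"):
            return TinyFactored(self.A, self.B @ other)
        return NotImplemented  # lets Python try other.__rmatmul__ instead of returning None

A, B = torch.randn(4, 2), torch.randn(2, 4)
result = TinyFactored(A, B) @ torch.randn(4, 3)
assert result.B.shape == (2, 3)
```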
… model (#629)

* Update convert_nanogpt_weights to have attention mask and handle case when there is no bias.

  Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* ran format
* Make beartyping dependency more forgiving

  Signed-off-by: Dashiell Stander <dstander@protonmail.com>

* generated lock file
* Added an error if cfg.d_mlp is None

---------

Signed-off-by: Dashiell Stander <dstander@protonmail.com>
Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: jlarson4 <jonahalarson@comcast.net>
* Added n_ctx override to TransformerBridge
* Prevent output of progress bars in T5 demo
* adds HookedTransformer.generate_stream()
* fixes mypy errors
* Adjusted for TransformerLens 3 changes
* Initial bridge generate stream
* TransformerBridge Generate Stream

---------

Co-authored-by: anthonyduong <anthonyduong9@gmail.com>
Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
… compatibility mode
* Fixed an issue with tokenize and concatenate
* Dialed in the approach to be per-doc
* Multi-GPU initial setup for TransformerBridge
* Added additional documentation note
* Add architecture adapter creation guide and a split-QKV example to the quantized LLaMA demo
* Exclude docs/build from black linting
* Fixed Quantization bug in TransformerLens 3.0
* Format fixes
Description
* … HookedTransformer system as possible
* Multi-GPU support for TransformerBridge via device_map; needs testing on actual multi-GPU devices
* Added generate_stream to HookedTransformer and TransformerBridge

Type of change
Please delete options that are not relevant.
Checklist: