37 changes: 12 additions & 25 deletions demos/T5.ipynb
@@ -88,14 +88,14 @@
"generated token: \",\", token id: 6\n",
"generated token: \"comment\", token id: 1670\n",
"generated token: \"\", token id: 3\n",
"generated token: \"\u00eates\", token id: 6738\n",
"generated token: \"êtes\", token id: 6738\n",
"generated token: \"-\", token id: 18\n",
"generated token: \"vous\", token id: 3249\n",
"generated token: \"\", token id: 3\n",
"generated token: \"?\", token id: 58\n",
"generated token: \"</s>\", token id: 1\n",
"translate English to French: Hello, how are you? \n",
" Bonjour, comment \u00eates-vous?\n"
" Bonjour, comment êtes-vous?\n"
]
}
],
@@ -206,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2026-03-05T18:28:00.478310Z",
@@ -215,21 +215,8 @@
"shell.execute_reply": "2026-03-05T18:28:00.629766Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hallo, magst du Bananen?\n"
]
}
],
"source": [
"prompt=\"translate English to German: Hello, do you like bananas?\"\n",
"\n",
"output = model.generate(prompt, do_sample=False, max_new_tokens=20)\n",
"print(output)"
]
"outputs": [],
"source": "prompt=\"translate English to German: Hello, do you like bananas?\"\n\noutput = model.generate(prompt, do_sample=False, max_new_tokens=20, verbose=False)\nprint(output)"
},
{
"cell_type": "markdown",
@@ -928,7 +915,7 @@
"outputs": [],
"source": [
"encoder_attn_pattern = cache[\"encoder_blocks.0.attn.hook_pattern\"]\n",
"input_str_tokens = [w.lstrip(\"\u2581\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]"
"input_str_tokens = [w.lstrip(\"\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]"
]
},
{
@@ -993,14 +980,14 @@
"data": {
"text/plain": [
"['<pad>',\n",
" '\u2581Bonjour',\n",
" '▁Bonjour',\n",
" ',',\n",
" '\u2581comment',\n",
" '\u2581',\n",
" '\u00eates',\n",
" '▁comment',\n",
" '',\n",
" 'êtes',\n",
" '-',\n",
" 'vous',\n",
" '\u2581',\n",
" '',\n",
" '?',\n",
" '</s>']"
]
@@ -1143,4 +1130,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
117 changes: 117 additions & 0 deletions tests/acceptance/model_bridge/test_n_ctx_override.py
@@ -0,0 +1,117 @@
"""Tests for the n_ctx override parameter on TransformerBridge.boot_transformers().

Uses load_weights=False so we can verify config plumbing without fighting HF's
weight-loading checks. Models with learned positional embeddings (e.g. GPT-2)
cannot have their n_ctx reduced at weight-load time — only rotary models can
freely resize. These tests verify the config is written correctly; users are
responsible for choosing n_ctx values their model supports.
"""

import logging

import pytest

from transformer_lens.model_bridge import TransformerBridge


def test_n_ctx_override_writes_to_correct_hf_field():
"""For GPT-2 the field is n_positions — overriding n_ctx should update it."""
bridge = TransformerBridge.boot_transformers(
"gpt2", device="cpu", n_ctx=256, load_weights=False
)
assert bridge.cfg.n_ctx == 256
assert bridge.original_model.config.n_positions == 256


def test_n_ctx_default_uses_model_max():
"""Without an override, cfg.n_ctx reflects the HF config's value."""
bridge = TransformerBridge.boot_transformers("gpt2", device="cpu", load_weights=False)
# GPT-2's n_positions default is 1024
assert bridge.cfg.n_ctx == 1024


def test_n_ctx_warns_when_above_default(caplog):
"""Overriding n_ctx above the model default should emit a logging.warning."""
with caplog.at_level(logging.WARNING):
TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=2048, load_weights=False)
assert any(
"larger than the model's default context length" in rec.message for rec in caplog.records
)


def test_n_ctx_combined_with_hf_config_overrides():
"""Explicit n_ctx should take precedence over hf_config_overrides for that field."""
bridge = TransformerBridge.boot_transformers(
"gpt2",
device="cpu",
n_ctx=256,
hf_config_overrides={"n_positions": 512}, # should be overridden by n_ctx=256
load_weights=False,
)
assert bridge.cfg.n_ctx == 256


# --- Coverage for code-review items #2, #4, #5, #7 ---


def test_n_ctx_zero_raises_value_error():
"""#2: n_ctx must be positive; zero should raise ValueError."""
with pytest.raises(ValueError, match="positive integer"):
TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=0, load_weights=False)


def test_n_ctx_negative_raises_value_error():
"""#2: n_ctx must be positive; negative should raise ValueError."""
with pytest.raises(ValueError, match="positive integer"):
TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=-1, load_weights=False)


def test_n_ctx_conflict_with_hf_config_overrides_warns(caplog):
"""#4: When both n_ctx and the same hf_config_overrides field are set with different values,
a warning should be emitted explaining that n_ctx wins."""
with caplog.at_level(logging.WARNING):
TransformerBridge.boot_transformers(
"gpt2",
device="cpu",
n_ctx=256,
hf_config_overrides={"n_positions": 512},
load_weights=False,
)
assert any(
"Both n_ctx=256 and hf_config_overrides['n_positions']" in rec.message
and "takes precedence" in rec.message
for rec in caplog.records
)


def test_n_ctx_no_conflict_when_values_match(caplog):
"""#4: If n_ctx and hf_config_overrides agree on the value, no conflict warning is emitted."""
with caplog.at_level(logging.WARNING):
TransformerBridge.boot_transformers(
"gpt2",
device="cpu",
n_ctx=256,
hf_config_overrides={"n_positions": 256}, # same as n_ctx
load_weights=False,
)
assert not any("takes precedence" in rec.message for rec in caplog.records)


def test_n_ctx_shrink_with_load_weights_gives_clear_error():
"""#5: Shrinking a learned-pos-embed model's n_ctx at weight-load time should raise
with a message explaining the cause and suggesting alternatives."""
with pytest.raises(RuntimeError) as exc_info:
TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=256, load_weights=True)
err = str(exc_info.value)
assert "n_ctx=256" in err
assert "learned positional embeddings" in err or "load_weights=False" in err


def test_n_ctx_override_verified_on_loaded_model():
"""#7: After load, the override should be visible on hf_model.config so users
can trust that the longer/shorter context is actually in effect."""
bridge = TransformerBridge.boot_transformers(
"gpt2", device="cpu", n_ctx=2048, load_weights=False
)
# The override persisted through model construction
assert bridge.original_model.config.n_positions == 2048
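Taken together, the tests pin down the intended calling pattern. A minimal usage sketch assuming only the API these tests exercise (no names beyond what the file above imports and asserts):

```python
from transformer_lens.model_bridge import TransformerBridge

# Config-only boot: n_ctx=256 is written into GPT-2's n_positions field.
# load_weights=False skips the pretrained weights, so shrinking the context
# cannot collide with the learned positional-embedding shape.
bridge = TransformerBridge.boot_transformers(
    "gpt2", device="cpu", n_ctx=256, load_weights=False
)
assert bridge.cfg.n_ctx == 256
assert bridge.original_model.config.n_positions == 256
```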
5 changes: 5 additions & 0 deletions transformer_lens/model_bridge/bridge.py
@@ -167,6 +167,7 @@ def boot_transformers(
trust_remote_code: bool = False,
model_class: Optional[type] = None,
hf_model: Optional[Any] = None,
n_ctx: Optional[int] = None,
) -> "TransformerBridge":
"""Boot a model from HuggingFace (alias for sources.transformers.boot).

@@ -183,6 +184,9 @@
hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful
for models loaded with custom configurations (e.g., quantization via
BitsAndBytesConfig). When provided, load_weights is ignored.
n_ctx: Optional context length override. Writes to the appropriate HF config field
for this model automatically (callers don't need to know the field name).
Warns if larger than the model's default context length.

Returns:
The bridge to the loaded model.
@@ -199,6 +203,7 @@
trust_remote_code=trust_remote_code,
model_class=model_class,
hf_model=hf_model,
n_ctx=n_ctx,
)

@property
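The precedence rule in the docstring is worth seeing inline; this sketch mirrors test_n_ctx_combined_with_hf_config_overrides and adds nothing beyond what that test asserts:

```python
from transformer_lens.model_bridge import TransformerBridge

# Explicit n_ctx wins over hf_config_overrides for the resolved field;
# because the two values disagree, a "takes precedence" warning is logged.
bridge = TransformerBridge.boot_transformers(
    "gpt2",
    device="cpu",
    n_ctx=256,
    hf_config_overrides={"n_positions": 512},  # loses to n_ctx=256
    load_weights=False,
)
assert bridge.cfg.n_ctx == 256
```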
89 changes: 88 additions & 1 deletion transformer_lens/model_bridge/sources/transformers.py
@@ -286,6 +286,7 @@ def boot(
trust_remote_code: bool = False,
model_class: Any | None = None,
hf_model: Any | None = None,
n_ctx: int | None = None,
) -> TransformerBridge:
"""Boot a model from HuggingFace.

@@ -302,6 +303,11 @@
hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for
models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig).
When provided, load_weights is ignored.
n_ctx: Optional context length override. The bridge normally uses the model's documented
max context from the HF config. Setting this writes to whichever HF field the model
uses (n_positions / max_position_embeddings / etc.), so callers don't need to know
the field name. If larger than the model's default, a warning is emitted — quality
may degrade past the trained length for rotary models.

Returns:
The bridge to the loaded model.
@@ -323,6 +329,54 @@
trust_remote_code=trust_remote_code,
token=_hf_token,
)
_n_ctx_field: str | None = None
if n_ctx is not None:
# Validation (#2): reject non-positive values before doing anything else.
if n_ctx <= 0:
raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.")
# Resolve n_ctx to whichever HF config field this model uses. Mirrors
# the order in map_default_transformer_lens_config so the TL config
# derivation picks up the override.
for _field in (
"n_positions",
"max_position_embeddings",
"max_context_length",
"max_length",
"seq_length",
):
if hasattr(hf_config, _field):
_n_ctx_field = _field
break
if _n_ctx_field is None:
raise ValueError(
f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on "
f"HF config for {model_name}. Use hf_config_overrides instead."
)
_default_n_ctx = getattr(hf_config, _n_ctx_field)
if _default_n_ctx is not None and n_ctx > _default_n_ctx:
logging.warning(
"Setting n_ctx=%d which is larger than the model's default "
"context length of %d. The model was not trained on sequences "
"this long and may produce unreliable results (especially for "
"rotary models without RoPE scaling).",
n_ctx,
_default_n_ctx,
)
# Conflict detection (#4): warn if the caller also set the same field
# via hf_config_overrides — explicit n_ctx wins but users should know.
if hf_config_overrides and _n_ctx_field in hf_config_overrides:
_conflicting_value = hf_config_overrides[_n_ctx_field]
if _conflicting_value != n_ctx:
logging.warning(
"Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. "
"The explicit n_ctx takes precedence.",
n_ctx,
_n_ctx_field,
_conflicting_value,
)
# Explicit n_ctx wins over hf_config_overrides for the resolved field.
hf_config_overrides = dict(hf_config_overrides or {})
hf_config_overrides[_n_ctx_field] = n_ctx
if hf_config_overrides:
hf_config.__dict__.update(hf_config_overrides)
tl_config = map_default_transformer_lens_config(hf_config)
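The field-resolution loop is the heart of this hunk; an illustrative standalone restatement (not library code — the candidate names and their order are copied verbatim from the loop above):

```python
# First attribute present on the HF config wins, mirroring the order used by
# map_default_transformer_lens_config so the TL config sees the same field.
_CTX_FIELDS = (
    "n_positions",              # e.g. GPT-2
    "max_position_embeddings",  # e.g. Llama, Mistral
    "max_context_length",
    "max_length",
    "seq_length",
)

def resolve_ctx_field(hf_config) -> str | None:
    for field in _CTX_FIELDS:
        if hasattr(hf_config, field):
            return field
    return None  # caller raises ValueError, as above
```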
@@ -409,13 +463,46 @@ def boot(
with contextlib.redirect_stdout(None):
hf_model = model_class.from_config(hf_config, **from_config_kwargs)
else:
hf_model = model_class.from_pretrained(model_name, **model_kwargs)
try:
hf_model = model_class.from_pretrained(model_name, **model_kwargs)
except RuntimeError as e:
# #5: HF refuses to load when positional-weight shapes don't match.
# If the user requested an n_ctx that conflicts with the saved weights
# (common for learned-pos-embed models like GPT-2), re-raise with a
# clearer message pointing them at the likely cause.
if n_ctx is not None and "ignore_mismatched_sizes" in str(e):
raise RuntimeError(
f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained "
f"weights' positional-embedding shape does not match the requested "
f"context length. This affects models with learned positional "
f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's "
f"default n_ctx, (2) pass load_weights=False if you only need "
f"config inspection, or (3) choose a rotary-embedding model "
f"(e.g. Llama, Mistral) which supports n_ctx changes without "
f"weight mismatch."
) from e
raise
if device is not None:
hf_model = hf_model.to(device)
# Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq)
for param in hf_model.parameters():
if param.is_floating_point() and param.dtype != dtype:
param.data = param.data.to(dtype=dtype)
# #7: Verify the n_ctx override actually took effect on the loaded model.
# If HF's config class silently dropped or normalized the value, warn so
# the user doesn't get misled into thinking longer sequences are supported.
if n_ctx is not None and _n_ctx_field is not None and hf_model is not None:
_actual = getattr(hf_model.config, _n_ctx_field, None)
if _actual != n_ctx:
logging.warning(
"n_ctx=%d was requested but hf_model.config.%s=%s after load. "
"The override may not have taken effect; the model may not "
"accept sequences longer than %s.",
n_ctx,
_n_ctx_field,
_actual,
_actual,
)
adapter.prepare_model(hf_model)
tokenizer = tokenizer
default_padding_side = getattr(adapter.cfg, "default_padding_side", None)
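Finally, the failure mode the new except-branch exists for, sketched as a user would hit it (mirrors test_n_ctx_shrink_with_load_weights_gives_clear_error; the message text is the one constructed above):

```python
from transformer_lens.model_bridge import TransformerBridge

# Shrinking a learned-pos-embed model while loading real weights trips the
# shape check inside from_pretrained; the re-raise explains the options.
try:
    TransformerBridge.boot_transformers(
        "gpt2", device="cpu", n_ctx=256, load_weights=True
    )
except RuntimeError as e:
    print(e)  # names n_ctx=256, suggests load_weights=False or a rotary model
```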