diff --git a/demos/T5.ipynb b/demos/T5.ipynb
index 1d225da96..fd4bc319e 100644
--- a/demos/T5.ipynb
+++ b/demos/T5.ipynb
@@ -88,14 +88,14 @@
      "generated token: \",\", token id: 6\n",
      "generated token: \"comment\", token id: 1670\n",
      "generated token: \"\", token id: 3\n",
-      "generated token: \"\u00eates\", token id: 6738\n",
+      "generated token: \"êtes\", token id: 6738\n",
      "generated token: \"-\", token id: 18\n",
      "generated token: \"vous\", token id: 3249\n",
      "generated token: \"\", token id: 3\n",
      "generated token: \"?\", token id: 58\n",
      "generated token: \"\", token id: 1\n",
      "translate English to French: Hello, how are you? \n",
-      " Bonjour, comment \u00eates-vous?\n"
+      " Bonjour, comment êtes-vous?\n"
     ]
    }
   ],
@@ -206,7 +206,7 @@ },
  {
   "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-03-05T18:28:00.478310Z",
     "iopub.status.busy": "2026-03-05T18:28:00.477921Z",
@@ -215,21 +215,8 @@
     "shell.execute_reply": "2026-03-05T18:28:00.629766Z"
    }
   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Hallo, magst du Bananen?\n"
-     ]
-    }
-   ],
-   "source": [
-    "prompt=\"translate English to German: Hello, do you like bananas?\"\n",
-    "\n",
-    "output = model.generate(prompt, do_sample=False, max_new_tokens=20)\n",
-    "print(output)"
-   ]
+   "outputs": [],
+   "source": "prompt=\"translate English to German: Hello, do you like bananas?\"\n\noutput = model.generate(prompt, do_sample=False, max_new_tokens=20, verbose=False)\nprint(output)"
  },
  {
   "cell_type": "markdown",
@@ -928,7 +915,7 @@
   "outputs": [],
   "source": [
    "encoder_attn_pattern = cache[\"encoder_blocks.0.attn.hook_pattern\"]\n",
-    "input_str_tokens = [w.lstrip(\"\u2581\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]"
+    "input_str_tokens = [w.lstrip(\"▁\") for w in tokenizer.convert_ids_to_tokens(input_ids[0])]"
   ]
  },
 {
@@ -993,14 +980,14 @@
    "data": {
     "text/plain": [
      "['',\n",
-      " '\u2581Bonjour',\n",
+      " '▁Bonjour',\n",
       " ',',\n",
-      " '\u2581comment',\n",
-      " '\u2581',\n",
-      " '\u00eates',\n",
+      " '▁comment',\n",
+      " '▁',\n",
+      " 'êtes',\n",
       " '-',\n",
       " 'vous',\n",
-      " '\u2581',\n",
+      " '▁',\n",
       " '?',\n",
       " '']"
     ]
@@ -1143,4 +1130,4 @@
  },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/tests/acceptance/model_bridge/test_n_ctx_override.py b/tests/acceptance/model_bridge/test_n_ctx_override.py
new file mode 100644
index 000000000..be35474e6
--- /dev/null
+++ b/tests/acceptance/model_bridge/test_n_ctx_override.py
@@ -0,0 +1,117 @@
+"""Tests for the n_ctx override parameter on TransformerBridge.boot_transformers().
+
+Uses load_weights=False so we can verify config plumbing without fighting HF's
+weight-loading checks. Models with learned positional embeddings (e.g. GPT-2)
+cannot have their n_ctx changed at weight-load time — only rotary models can
+freely resize. These tests verify the config is written correctly; users are
+responsible for choosing n_ctx values their model supports.
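+
+Example (a config-only sketch mirroring the tests below; load_weights=False skips
+weight loading entirely):
+
+    bridge = TransformerBridge.boot_transformers(
+        "gpt2", device="cpu", n_ctx=256, load_weights=False
+    )
+    assert bridge.cfg.n_ctx == 256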
+""" + +import logging + +import pytest + +from transformer_lens.model_bridge import TransformerBridge + + +def test_n_ctx_override_writes_to_correct_hf_field(): + """For GPT-2 the field is n_positions — overriding n_ctx should update it.""" + bridge = TransformerBridge.boot_transformers( + "gpt2", device="cpu", n_ctx=256, load_weights=False + ) + assert bridge.cfg.n_ctx == 256 + assert bridge.original_model.config.n_positions == 256 + + +def test_n_ctx_default_uses_model_max(): + """Without an override, cfg.n_ctx reflects the HF config's value.""" + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu", load_weights=False) + # GPT-2's n_positions default is 1024 + assert bridge.cfg.n_ctx == 1024 + + +def test_n_ctx_warns_when_above_default(caplog): + """Overriding n_ctx above the model default should emit a logging.warning.""" + with caplog.at_level(logging.WARNING): + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=2048, load_weights=False) + assert any( + "larger than the model's default context length" in rec.message for rec in caplog.records + ) + + +def test_n_ctx_combined_with_hf_config_overrides(): + """Explicit n_ctx should take precedence over hf_config_overrides for that field.""" + bridge = TransformerBridge.boot_transformers( + "gpt2", + device="cpu", + n_ctx=256, + hf_config_overrides={"n_positions": 512}, # should be overridden by n_ctx=256 + load_weights=False, + ) + assert bridge.cfg.n_ctx == 256 + + +# --- Coverage for code-review items #2, #4, #5, #7 --- + + +def test_n_ctx_zero_raises_value_error(): + """#2: n_ctx must be positive; zero should raise ValueError.""" + with pytest.raises(ValueError, match="positive integer"): + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=0, load_weights=False) + + +def test_n_ctx_negative_raises_value_error(): + """#2: n_ctx must be positive; negative should raise ValueError.""" + with pytest.raises(ValueError, match="positive integer"): + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=-1, load_weights=False) + + +def test_n_ctx_conflict_with_hf_config_overrides_warns(caplog): + """#4: When both n_ctx and the same hf_config_overrides field are set with different values, + a warning should be emitted explaining that n_ctx wins.""" + with caplog.at_level(logging.WARNING): + TransformerBridge.boot_transformers( + "gpt2", + device="cpu", + n_ctx=256, + hf_config_overrides={"n_positions": 512}, + load_weights=False, + ) + assert any( + "Both n_ctx=256 and hf_config_overrides['n_positions']" in rec.message + and "takes precedence" in rec.message + for rec in caplog.records + ) + + +def test_n_ctx_no_conflict_when_values_match(caplog): + """#4: If n_ctx and hf_config_overrides agree on the value, no conflict warning is emitted.""" + with caplog.at_level(logging.WARNING): + TransformerBridge.boot_transformers( + "gpt2", + device="cpu", + n_ctx=256, + hf_config_overrides={"n_positions": 256}, # same as n_ctx + load_weights=False, + ) + assert not any("takes precedence" in rec.message for rec in caplog.records) + + +def test_n_ctx_shrink_with_load_weights_gives_clear_error(): + """#5: Shrinking a learned-pos-embed model's n_ctx at weight-load time should raise + with a message explaining the cause and suggesting alternatives.""" + with pytest.raises(RuntimeError) as exc_info: + TransformerBridge.boot_transformers("gpt2", device="cpu", n_ctx=256, load_weights=True) + err = str(exc_info.value) + assert "n_ctx=256" in err + assert "learned positional embeddings" in err or 
"load_weights=False" in err + + +def test_n_ctx_override_verified_on_loaded_model(): + """#7: After load, the override should be visible on hf_model.config so users + can trust that the longer/shorter context is actually in effect.""" + bridge = TransformerBridge.boot_transformers( + "gpt2", device="cpu", n_ctx=2048, load_weights=False + ) + # The override persisted through model construction + assert bridge.original_model.config.n_positions == 2048 diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index cfac0b51c..72eabe0c4 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -167,6 +167,7 @@ def boot_transformers( trust_remote_code: bool = False, model_class: Optional[type] = None, hf_model: Optional[Any] = None, + n_ctx: Optional[int] = None, ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -183,6 +184,9 @@ def boot_transformers( hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored. + n_ctx: Optional context length override. Writes to the appropriate HF config field + for this model automatically (callers don't need to know the field name). + Warns if larger than the model's default context length. Returns: The bridge to the loaded model. @@ -199,6 +203,7 @@ def boot_transformers( trust_remote_code=trust_remote_code, model_class=model_class, hf_model=hf_model, + n_ctx=n_ctx, ) @property diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 0169da4dc..99b90a968 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -286,6 +286,7 @@ def boot( trust_remote_code: bool = False, model_class: Any | None = None, hf_model: Any | None = None, + n_ctx: int | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace. @@ -302,6 +303,11 @@ def boot( hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored. + n_ctx: Optional context length override. The bridge normally uses the model's documented + max context from the HF config. Setting this writes to whichever HF field the model + uses (n_positions / max_position_embeddings / etc.), so callers don't need to know + the field name. If larger than the model's default, a warning is emitted — quality + may degrade past the trained length for rotary models. Returns: The bridge to the loaded model. @@ -323,6 +329,54 @@ def boot( trust_remote_code=trust_remote_code, token=_hf_token, ) + _n_ctx_field: str | None = None + if n_ctx is not None: + # Validation (#2): reject non-positive values before doing anything else. + if n_ctx <= 0: + raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.") + # Resolve n_ctx to whichever HF config field this model uses. Mirrors + # the order in map_default_transformer_lens_config so the TL config + # derivation picks up the override. 
diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py
index 0169da4dc..99b90a968 100644
--- a/transformer_lens/model_bridge/sources/transformers.py
+++ b/transformer_lens/model_bridge/sources/transformers.py
@@ -286,6 +286,7 @@ def boot(
     trust_remote_code: bool = False,
     model_class: Any | None = None,
     hf_model: Any | None = None,
+    n_ctx: int | None = None,
 ) -> TransformerBridge:
     """Boot a model from HuggingFace.

@@ -302,6 +303,11 @@
         hf_model: Optional pre-loaded HuggingFace model to use instead of loading one.
             Useful for models loaded with custom configurations (e.g., quantization
             via BitsAndBytesConfig). When provided, load_weights is ignored.
+        n_ctx: Optional context length override. The bridge normally uses the model's documented
+            max context from the HF config. Setting this writes to whichever HF field the model
+            uses (n_positions / max_position_embeddings / etc.), so callers don't need to know
+            the field name. If larger than the model's default, a warning is emitted — quality
+            may degrade past the trained length for rotary models.

     Returns:
         The bridge to the loaded model.
@@ -323,6 +329,54 @@
         trust_remote_code=trust_remote_code,
         token=_hf_token,
     )
+    _n_ctx_field: str | None = None
+    if n_ctx is not None:
+        # Validation (#2): reject non-positive values before doing anything else.
+        if n_ctx <= 0:
+            raise ValueError(f"n_ctx must be a positive integer, got n_ctx={n_ctx}.")
+        # Resolve n_ctx to whichever HF config field this model uses. Mirrors
+        # the order in map_default_transformer_lens_config so the TL config
+        # derivation picks up the override.
+        for _field in (
+            "n_positions",
+            "max_position_embeddings",
+            "max_context_length",
+            "max_length",
+            "seq_length",
+        ):
+            if hasattr(hf_config, _field):
+                _n_ctx_field = _field
+                break
+        if _n_ctx_field is None:
+            raise ValueError(
+                f"Cannot apply n_ctx={n_ctx}: no recognized context-length field on "
+                f"HF config for {model_name}. Use hf_config_overrides instead."
+            )
+        _default_n_ctx = getattr(hf_config, _n_ctx_field)
+        if _default_n_ctx is not None and n_ctx > _default_n_ctx:
+            logging.warning(
+                "Setting n_ctx=%d which is larger than the model's default "
+                "context length of %d. The model was not trained on sequences "
+                "this long and may produce unreliable results (especially for "
+                "rotary models without RoPE scaling).",
+                n_ctx,
+                _default_n_ctx,
+            )
+        # Conflict detection (#4): warn if the caller also set the same field
+        # via hf_config_overrides — explicit n_ctx wins but users should know.
+        if hf_config_overrides and _n_ctx_field in hf_config_overrides:
+            _conflicting_value = hf_config_overrides[_n_ctx_field]
+            if _conflicting_value != n_ctx:
+                logging.warning(
+                    "Both n_ctx=%d and hf_config_overrides['%s']=%s were provided. "
+                    "The explicit n_ctx takes precedence.",
+                    n_ctx,
+                    _n_ctx_field,
+                    _conflicting_value,
+                )
+        # Explicit n_ctx wins over hf_config_overrides for the resolved field.
+        hf_config_overrides = dict(hf_config_overrides or {})
+        hf_config_overrides[_n_ctx_field] = n_ctx
     if hf_config_overrides:
         hf_config.__dict__.update(hf_config_overrides)
     tl_config = map_default_transformer_lens_config(hf_config)
@@ -409,13 +463,46 @@
         with contextlib.redirect_stdout(None):
             hf_model = model_class.from_config(hf_config, **from_config_kwargs)
     else:
-        hf_model = model_class.from_pretrained(model_name, **model_kwargs)
+        try:
+            hf_model = model_class.from_pretrained(model_name, **model_kwargs)
+        except RuntimeError as e:
+            # #5: HF refuses to load when positional-weight shapes don't match.
+            # If the user requested an n_ctx that conflicts with the saved weights
+            # (common for learned-pos-embed models like GPT-2), re-raise with a
+            # clearer message pointing them at the likely cause.
+            if n_ctx is not None and "ignore_mismatched_sizes" in str(e):
+                raise RuntimeError(
+                    f"Failed to load {model_name} with n_ctx={n_ctx}: the pretrained "
+                    f"weights' positional-embedding shape does not match the requested "
+                    f"context length. This affects models with learned positional "
+                    f"embeddings (e.g. GPT-2, OPT). Options: (1) use the model's "
+                    f"default n_ctx, (2) pass load_weights=False if you only need "
+                    f"config inspection, or (3) choose a rotary-embedding model "
+                    f"(e.g. Llama, Mistral) which supports n_ctx changes without "
+                    f"weight mismatch."
+                ) from e
+            raise
     if device is not None:
         hf_model = hf_model.to(device)
     # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq)
     for param in hf_model.parameters():
         if param.is_floating_point() and param.dtype != dtype:
             param.data = param.data.to(dtype=dtype)
+    # #7: Verify the n_ctx override actually took effect on the loaded model.
+    # If HF's config class silently dropped or normalized the value, warn so
+    # the user doesn't get misled into thinking longer sequences are supported.
+    if n_ctx is not None and _n_ctx_field is not None and hf_model is not None:
+        _actual = getattr(hf_model.config, _n_ctx_field, None)
+        if _actual != n_ctx:
+            logging.warning(
+                "n_ctx=%d was requested but hf_model.config.%s=%s after load. "
+                "The override may not have taken effect; the model may not "
+                "accept sequences longer than %s.",
+                n_ctx,
+                _n_ctx_field,
+                _actual,
+                _actual,
+            )
     adapter.prepare_model(hf_model)
     tokenizer = tokenizer
     default_padding_side = getattr(adapter.cfg, "default_padding_side", None)