From d082e86ce8a5d7f5a5222a09db3705b5bc2ca27a Mon Sep 17 00:00:00 2001
From: jlarson4
Date: Wed, 29 Apr 2026 17:08:22 -0500
Subject: [PATCH 1/2] Fixed Quantization bug in TransformerLens 3.0

---
 .../model_bridge/test_bridge_integration.py  | 82 +++++++++++++++++++
 .../generalized_components/attention.py      | 13 ++-
 .../generalized_components/base.py           | 13 ++-
 3 files changed, 100 insertions(+), 8 deletions(-)

diff --git a/tests/integration/model_bridge/test_bridge_integration.py b/tests/integration/model_bridge/test_bridge_integration.py
index d9b11228c..1fe3dc99e 100644
--- a/tests/integration/model_bridge/test_bridge_integration.py
+++ b/tests/integration/model_bridge/test_bridge_integration.py
@@ -718,6 +718,88 @@ def hook_fn(grad, hook=None):
     assert hook_called["bridge"], "TransformerBridge backward hook should now be called correctly"
 
 
+def test_AttentionBridge_preserves_fp_input_when_first_param_is_quantized():
+    """Bridge must not cast fp inputs to integer storage dtype.
+
+    Regression test for an AttentionBridge / GeneralizedComponent bug where
+    `target_dtype = next(parameters()).dtype` returned the storage dtype of
+    quantized weights (uint8 for BnB Params4bit, int32 for GPTQ, etc.). When
+    the first parameter happened to be quantized, the bridge cast fp32
+    hidden_states to that integer dtype before passing them to HF — destroying
+    precision and producing gibberish logits on every quantized model.
+
+    Fakes a "quantized first parameter" by replacing q_proj.weight with a
+    uint8 tensor, then runs a forward and asserts that the input the
+    original component receives is still floating-point.
+    """
+    from transformer_lens.model_bridge.generalized_components.attention import (
+        AttentionBridge,
+    )
+
+    # Use tiny Mistral — it's a plain AttentionBridge (not JointQKV).
+    bridge: TransformerBridge = TransformerBridge.boot_transformers(  # type: ignore
+        "trl-internal-testing/tiny-MistralForCausalLM-0.2", device="cpu"
+    )
+
+    attn_bridge = bridge.blocks[0].attn  # type: ignore[attr-defined]
+    assert type(attn_bridge).__name__ == "AttentionBridge", (
+        f"Expected plain AttentionBridge, got {type(attn_bridge).__name__}"
+    )
+    assert isinstance(attn_bridge, AttentionBridge)
+
+    original = attn_bridge.original_component
+    assert original is not None, "AttentionBridge.original_component not set"
+
+    # Fake-quantize q_proj to uint8 storage — mirrors BnB Params4bit.
+    fp_weight = original.q_proj.weight
+    original.q_proj.weight = torch.nn.Parameter(
+        torch.zeros(fp_weight.shape, dtype=torch.uint8), requires_grad=False
+    )
+    assert next(original.parameters()).dtype == torch.uint8, (
+        "Test setup: first param should be uint8 to trigger the bug condition"
+    )
+
+    # Capture what dtype reaches the original component's forward.
+    received_dtype: list = []
+    orig_forward = original.forward
+
+    def capture(*args, **kwargs):
+        if "hidden_states" in kwargs:
+            received_dtype.append(kwargs["hidden_states"].dtype)
+        elif args:
+            received_dtype.append(args[0].dtype)
+        # Don't actually run forward — fake-quantized weight would error.
+        # Return a shape-compatible dummy. HF Mistral attention returns a tuple.
+        bsz, seq, d_model = (
+            kwargs.get("hidden_states", args[0] if args else None)
+        ).shape
+        n_heads = bridge.cfg.n_heads  # type: ignore[attr-defined]
+        return (
+            torch.zeros(bsz, seq, d_model, dtype=torch.float32),
+            torch.zeros(bsz, n_heads, seq, seq, dtype=torch.float32),
+        )
+
+    original.forward = capture  # type: ignore[method-assign]
+    try:
+        test_input = torch.tensor([[1, 2, 3, 4, 5]])
+        with torch.no_grad():
+            try:
+                bridge(test_input)
+            except Exception:
+                pass  # downstream may fail; we only care what reached attn forward
+    finally:
+        original.forward = orig_forward  # type: ignore[method-assign]
+        original.q_proj.weight = fp_weight
+
+    assert len(received_dtype) > 0, "Original attention forward never called"
+    for dt in received_dtype:
+        assert dt.is_floating_point, (
+            f"Bridge passed dtype={dt} to original attention forward, but it must be "
+            f"floating point. Regression of the AttentionBridge dtype-cast bug — "
+            f"target_dtype must skip non-fp (quantized-storage) parameters."
+        )
+
+
 @pytest.mark.skipif(bool(os.getenv("CI")), reason="Skip Gemma2 test in CI to avoid timeout")
 def test_TransformerBridge_gemma2_forward():
     """Test that TransformerBridge properly handles Gemma2's position_embeddings.
diff --git a/transformer_lens/model_bridge/generalized_components/attention.py b/transformer_lens/model_bridge/generalized_components/attention.py
index 34da4e280..22504b294 100644
--- a/transformer_lens/model_bridge/generalized_components/attention.py
+++ b/transformer_lens/model_bridge/generalized_components/attention.py
@@ -622,11 +622,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
             raise RuntimeError(
                 f"Original component not set for {self.name}. Call set_original_component() first."
             )
+        # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
+        # HQQ, torchao) are stored in integer dtypes and dequantized internally
+        # during matmul. The compute dtype must come from a fp parameter; casting
+        # fp inputs to an integer storage dtype destroys precision.
         target_dtype = None
-        try:
-            target_dtype = next(self.original_component.parameters()).dtype
-        except StopIteration:
-            pass
+        for p in self.original_component.parameters():
+            if not p.dtype.is_floating_point:
+                continue
+            target_dtype = p.dtype
+            break
         if "query_input" in kwargs:
             hooked = self.hook_in(kwargs["query_input"])
             if (
diff --git a/transformer_lens/model_bridge/generalized_components/base.py b/transformer_lens/model_bridge/generalized_components/base.py
index 20e44fbbb..ae6787b67 100644
--- a/transformer_lens/model_bridge/generalized_components/base.py
+++ b/transformer_lens/model_bridge/generalized_components/base.py
@@ -274,11 +274,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
             raise RuntimeError(
                 f"Original component not set for {self.name}. Call set_original_component() first."
             )
+        # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
+        # HQQ, torchao) are stored in integer dtypes and dequantized internally
+        # during matmul. The compute dtype must come from a fp parameter; casting
+        # fp inputs to an integer storage dtype destroys precision.
         target_dtype = None
-        try:
-            target_dtype = next(original_component.parameters()).dtype
-        except StopIteration:
-            pass
+        for p in original_component.parameters():
+            if not p.dtype.is_floating_point:
+                continue
+            target_dtype = p.dtype
+            break
         input_arg_names = [
             "input",
             "hidden_states",

From ebf61d7b07b7d8e78e8930750e752b31b9968dca Mon Sep 17 00:00:00 2001
From: jlarson4
Date: Wed, 29 Apr 2026 17:26:45 -0500
Subject: [PATCH 2/2] Format fixes

---
 .../model_bridge/test_bridge_integration.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/tests/integration/model_bridge/test_bridge_integration.py b/tests/integration/model_bridge/test_bridge_integration.py
index 1fe3dc99e..febe910b0 100644
--- a/tests/integration/model_bridge/test_bridge_integration.py
+++ b/tests/integration/model_bridge/test_bridge_integration.py
@@ -742,9 +742,9 @@ def test_AttentionBridge_preserves_fp_input_when_first_param_is_quantized():
     )
 
     attn_bridge = bridge.blocks[0].attn  # type: ignore[attr-defined]
-    assert type(attn_bridge).__name__ == "AttentionBridge", (
-        f"Expected plain AttentionBridge, got {type(attn_bridge).__name__}"
-    )
+    assert (
+        type(attn_bridge).__name__ == "AttentionBridge"
+    ), f"Expected plain AttentionBridge, got {type(attn_bridge).__name__}"
     assert isinstance(attn_bridge, AttentionBridge)
 
     original = attn_bridge.original_component
@@ -755,9 +755,9 @@ def test_AttentionBridge_preserves_fp_input_when_first_param_is_quantized():
     original.q_proj.weight = torch.nn.Parameter(
         torch.zeros(fp_weight.shape, dtype=torch.uint8), requires_grad=False
     )
-    assert next(original.parameters()).dtype == torch.uint8, (
-        "Test setup: first param should be uint8 to trigger the bug condition"
-    )
+    assert (
+        next(original.parameters()).dtype == torch.uint8
+    ), "Test setup: first param should be uint8 to trigger the bug condition"
 
     # Capture what dtype reaches the original component's forward.
     received_dtype: list = []
@@ -770,9 +770,7 @@ def capture(*args, **kwargs):
             received_dtype.append(args[0].dtype)
         # Don't actually run forward — fake-quantized weight would error.
        # Return a shape-compatible dummy. HF Mistral attention returns a tuple.
-        bsz, seq, d_model = (
-            kwargs.get("hidden_states", args[0] if args else None)
-        ).shape
+        bsz, seq, d_model = (kwargs.get("hidden_states", args[0] if args else None)).shape
         n_heads = bridge.cfg.n_heads  # type: ignore[attr-defined]
         return (
             torch.zeros(bsz, seq, d_model, dtype=torch.float32),
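
Note on the fix (an illustrative sketch, not part of either commit above): the snippet below reproduces the dtype-selection bug in isolation. A plain uint8 nn.Parameter stands in for quantized weight storage (for example a bitsandbytes Params4bit); the module and attribute names are made up for the example and are not TransformerLens or Hugging Face API.

import torch
import torch.nn as nn


class FakeQuantizedAttn(nn.Module):
    """Toy module whose first parameter uses integer (quantized) storage."""

    def __init__(self) -> None:
        super().__init__()
        # uint8 storage mimics packed 4-bit/8-bit quantized weights.
        self.q_proj_weight = nn.Parameter(
            torch.zeros(16, 16, dtype=torch.uint8), requires_grad=False
        )
        # A genuine floating-point parameter later in the iteration order.
        self.o_proj_weight = nn.Parameter(torch.zeros(16, 16, dtype=torch.float16))


module = FakeQuantizedAttn()

# Old behaviour: take the dtype of whichever parameter happens to come first.
buggy_dtype = next(module.parameters()).dtype  # torch.uint8
# Fixed behaviour: skip non-floating-point (quantized-storage) parameters.
fixed_dtype = next(
    (p.dtype for p in module.parameters() if p.dtype.is_floating_point), None
)  # torch.float16

hidden_states = torch.randn(1, 5, 16)  # fp32 activations
# Casting to the storage dtype truncates every activation to an integer,
# which is what produced gibberish logits on quantized models.
assert not hidden_states.to(buggy_dtype).dtype.is_floating_point
# Casting to a genuine fp dtype keeps the activations meaningful.
assert hidden_states.to(fixed_dtype).dtype.is_floating_point

The sketch needs only torch. With the pre-fix selection logic, buggy_dtype (torch.uint8) is what would have been chosen as target_dtype and used to cast the incoming hidden states.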