80 changes: 80 additions & 0 deletions tests/integration/model_bridge/test_bridge_integration.py
@@ -718,6 +718,86 @@ def hook_fn(grad, hook=None):
assert hook_called["bridge"], "TransformerBridge backward hook should now be called correctly"


def test_AttentionBridge_preserves_fp_input_when_first_param_is_quantized():
"""Bridge must not cast fp inputs to integer storage dtype.

Regression test for an AttentionBridge / GeneralizedComponent bug where
`target_dtype = next(parameters()).dtype` returned the storage dtype of
quantized weights (uint8 for BnB Params4bit, int32 for GPTQ, etc.). When
the first parameter happened to be quantized, the bridge cast fp32 hidden_states
to that integer dtype before passing them to HF, destroying precision and
producing gibberish logits on every quantized model.

Fakes a "quantized first parameter" by replacing q_proj.weight with a
uint8 tensor, then runs a forward and asserts the input the original
component receives is still floating-point.
"""
from transformer_lens.model_bridge.generalized_components.attention import (
AttentionBridge,
)

# Use tiny Mistral — it's a plain AttentionBridge (not JointQKV).
bridge: TransformerBridge = TransformerBridge.boot_transformers( # type: ignore
"trl-internal-testing/tiny-MistralForCausalLM-0.2", device="cpu"
)

attn_bridge = bridge.blocks[0].attn # type: ignore[attr-defined]
assert (
type(attn_bridge).__name__ == "AttentionBridge"
), f"Expected plain AttentionBridge, got {type(attn_bridge).__name__}"
assert isinstance(attn_bridge, AttentionBridge)

original = attn_bridge.original_component
assert original is not None, "AttentionBridge.original_component not set"

# Fake-quantize q_proj to uint8 storage — mirrors BnB Params4bit.
fp_weight = original.q_proj.weight
original.q_proj.weight = torch.nn.Parameter(
torch.zeros(fp_weight.shape, dtype=torch.uint8), requires_grad=False
)
assert (
next(original.parameters()).dtype == torch.uint8
), "Test setup: first param should be uint8 to trigger the bug condition"

# Capture what dtype reaches the original component's forward.
received_dtype: list = []
orig_forward = original.forward

def capture(*args, **kwargs):
if "hidden_states" in kwargs:
received_dtype.append(kwargs["hidden_states"].dtype)
elif args:
received_dtype.append(args[0].dtype)
# Don't actually run forward — fake-quantized weight would error.
# Return a shape-compatible dummy. HF Mistral attention returns a tuple.
bsz, seq, d_model = (kwargs.get("hidden_states", args[0] if args else None)).shape
n_heads = bridge.cfg.n_heads # type: ignore[attr-defined]
return (
torch.zeros(bsz, seq, d_model, dtype=torch.float32),
torch.zeros(bsz, n_heads, seq, seq, dtype=torch.float32),
)

original.forward = capture # type: ignore[method-assign]
try:
test_input = torch.tensor([[1, 2, 3, 4, 5]])
with torch.no_grad():
try:
bridge(test_input)
except Exception:
pass # downstream may fail; we only care what reached attn forward
finally:
original.forward = orig_forward # type: ignore[method-assign]
original.q_proj.weight = fp_weight

assert len(received_dtype) > 0, "Original attention forward never called"
for dt in received_dtype:
assert dt.is_floating_point, (
f"Bridge passed dtype={dt} to original attention forward, but it must be "
f"floating point. Regression of the AttentionBridge dtype-cast bug — "
f"target_dtype must skip non-fp (quantized-storage) parameters."
)
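
The failure mode this test guards against can also be reproduced outside the bridge. The sketch below is illustrative only and is not part of the PR's diff; the toy module and its parameter names are invented, and it assumes only standard PyTorch behavior plus the bitsandbytes-style uint8 storage mentioned in the docstring.

# Illustrative sketch (not part of the PR): why "dtype of the first parameter"
# goes wrong when that parameter uses quantized integer storage.
import torch


class ToyQuantizedAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # First registered parameter mimics bitsandbytes-style uint8 storage.
        self.q_proj_weight = torch.nn.Parameter(
            torch.zeros(4, 4, dtype=torch.uint8), requires_grad=False
        )
        # A later parameter carries the actual compute dtype.
        self.o_proj_weight = torch.nn.Parameter(torch.zeros(4, 4, dtype=torch.float32))


module = ToyQuantizedAttention()
hidden_states = torch.randn(1, 3, 4)

buggy_dtype = next(module.parameters()).dtype  # torch.uint8 (old behavior)
fixed_dtype = next(
    (p.dtype for p in module.parameters() if p.dtype.is_floating_point), None
)  # torch.float32 (behavior after this PR)

# Casting fp activations to the integer storage dtype throws away the fractional
# information (and mangles negative values): the precision loss described above.
print(buggy_dtype, fixed_dtype)
print(hidden_states.to(buggy_dtype))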


@pytest.mark.skipif(bool(os.getenv("CI")), reason="Skip Gemma2 test in CI to avoid timeout")
def test_TransformerBridge_gemma2_forward():
"""Test that TransformerBridge properly handles Gemma2's position_embeddings.
@@ -622,11 +622,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
raise RuntimeError(
f"Original component not set for {self.name}. Call set_original_component() first."
)
# Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
# HQQ, torchao) are stored in integer dtypes and dequantized internally
# during matmul. The compute dtype must come from a fp parameter; casting
# fp inputs to an integer storage dtype destroys precision.
target_dtype = None
try:
target_dtype = next(self.original_component.parameters()).dtype
except StopIteration:
pass
for p in self.original_component.parameters():
if not p.dtype.is_floating_point:
continue
target_dtype = p.dtype
break
if "query_input" in kwargs:
hooked = self.hook_in(kwargs["query_input"])
if (
13 changes: 9 additions & 4 deletions transformer_lens/model_bridge/generalized_components/base.py
@@ -274,11 +274,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
raise RuntimeError(
f"Original component not set for {self.name}. Call set_original_component() first."
)
# Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
# HQQ, torchao) are stored in integer dtypes and dequantized internally
# during matmul. The compute dtype must come from a fp parameter; casting
# fp inputs to an integer storage dtype destroys precision.
target_dtype = None
try:
target_dtype = next(original_component.parameters()).dtype
except StopIteration:
pass
for p in original_component.parameters():
if not p.dtype.is_floating_point:
continue
target_dtype = p.dtype
break
input_arg_names = [
"input",
"hidden_states",
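
The same selection loop appears in both forward() hunks above. As a standalone sanity check of its behavior, the snippet below exercises it in isolation; it is illustrative only and not part of the PR, and the helper name and toy modules are invented for the example.

# Standalone check of the dtype-selection logic added above (illustrative only).
import torch


def first_fp_param_dtype(module: torch.nn.Module):
    """Same shape as the diff's loop: first floating-point parameter dtype, else None."""
    target_dtype = None
    for p in module.parameters():
        if not p.dtype.is_floating_point:
            continue
        target_dtype = p.dtype
        break
    return target_dtype


quantized_first = torch.nn.Module()
quantized_first.w_quant = torch.nn.Parameter(
    torch.zeros(2, dtype=torch.uint8), requires_grad=False
)
quantized_first.w_fp = torch.nn.Parameter(torch.zeros(2, dtype=torch.float16))

empty = torch.nn.Module()

assert first_fp_param_dtype(quantized_first) == torch.float16
# The None fallback covers parameter-less modules, replacing the old
# try/except StopIteration around next(parameters()).
assert first_fp_param_dtype(empty) is None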