
Make Gemma4ClippableLinear inherit from nn.Linear for PEFT/LoRA compatibility #45388

Closed
albertorkive wants to merge 2 commits into huggingface:main from albertorkive:gemma4-clippable-linear-lora

Conversation

@albertorkive

What does this PR do?

Makes Gemma4ClippableLinear inherit from nn.Linear instead of wrapping one via composition, enabling PEFT/LoRA to discover and target vision/audio encoder layers.

Problem: PEFT's LoRA module discovery uses isinstance(module, nn.Linear) to find targetable layers. The current Gemma4ClippableLinear subclasses nn.Module and stores an internal self.linear = nn.Linear(...), so PEFT skips all vision and audio encoder projections (q_proj, k_proj, v_proj, o_proj, ffw layers). Users cannot fine-tune the Gemma4 vision tower with LoRA.
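
A minimal, self-contained illustration of the discovery problem (the simplified class below mirrors the description above; it is a sketch, not the actual transformers code):

import torch.nn as nn

class ComposedClippableLinear(nn.Module):  # current design: composition
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x).clamp(-10.0, 10.0)  # illustrative clipping bounds

m = ComposedClippableLinear(4, 4)
print(isinstance(m, nn.Linear))         # False -> PEFT's LoRA discovery skips it
print(isinstance(m.linear, nn.Linear))  # True  -> only the inner module qualifies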

Fix:

  • Change Gemma4ClippableLinear to inherit from nn.Linear directly (see the sketch after this list)
  • Weight lives as self.weight (standard nn.Linear) instead of self.linear.weight
  • Clipping behavior is fully preserved
  • Add _remap_legacy_keys state dict pre-hook for backward compatibility with existing checkpoints that store weight under "linear.weight"
  • Update weight converter to use new key names
  • Fix three .linear.weight references in forward methods
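
A minimal sketch of the reworked class under stated assumptions: the clip_min/clip_max names, the bias argument, and clamping the input (rather than the output) are illustrative choices, not the actual Gemma4 signature:

import torch.nn as nn

class Gemma4ClippableLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, clip_min=None, clip_max=None):
        super().__init__(in_features, out_features, bias=bias)
        self.clip_min = clip_min
        self.clip_max = clip_max
        # remap legacy "linear.weight" checkpoint keys before load_state_dict runs
        self._register_load_state_dict_pre_hook(self._remap_legacy_keys)

    @staticmethod
    def _remap_legacy_keys(state_dict, prefix, *args, **kwargs):
        old_key = prefix + "linear.weight"
        if old_key in state_dict:
            state_dict[prefix + "weight"] = state_dict.pop(old_key)

    def forward(self, x):
        if self.clip_min is not None or self.clip_max is not None:
            x = x.clamp(self.clip_min, self.clip_max)
        return super().forward(x)

Because the class now is an nn.Linear, isinstance(module, nn.Linear) holds and LoRA discovery can wrap it directly.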

Backward compatibility: Existing checkpoints with linear.weight keys load correctly via the pre-hook remap. Verified with strict=True:

# Old checkpoint format
old_sd = {"linear.weight": ..., "input_min": ..., ...}

# Loads into new class without errors
new_module.load_state_dict(old_sd, strict=True)  # works

Files changed

  • src/transformers/models/gemma4/modular_gemma4.py — source of truth
  • src/transformers/models/gemma4/modeling_gemma4.py — generated, matching changes
  • src/transformers/models/gemma4/convert_gemma4_weights.py — updated key names

How to reproduce the bug

from transformers import Gemma4ForConditionalGeneration
import torch.nn as nn

model = Gemma4ForConditionalGeneration.from_pretrained("google/gemma-4-12b-it")

# These are all False — PEFT can't target them
for name, mod in model.named_modules():
    if "vision_tower" in name and hasattr(mod, "linear"):
        print(f"{name}: isinstance(nn.Linear) = {isinstance(mod, nn.Linear)}")
        # Prints: False for every ClippableLinear module

After this PR, all Gemma4ClippableLinear modules pass isinstance(mod, nn.Linear).

Make Gemma4ClippableLinear inherit from nn.Linear for PEFT/LoRA compatibility

Gemma4ClippableLinear previously subclassed nn.Module and wrapped an
internal nn.Linear via composition. This prevented PEFT/LoRA from
discovering these layers since it uses isinstance(module, nn.Linear).

Change ClippableLinear to inherit from nn.Linear directly, preserving
the optional input/output clamping behavior. Add a state dict pre-hook
to remap legacy "linear.weight" keys from existing checkpoints to the
new "weight" key for backward compatibility.

Also update the weight converter and fix three .linear.weight references
in forward methods.
@Rocketknight1
Member

cc @ArthurZucker @Cyrilvallez

Collaborator

@ArthurZucker ArthurZucker left a comment

SGTM! I don't know why this wasn't done earlier (cc @Cyrilvallez, since you worked on this: might there have been a reason?)

Comment on lines +155 to +161
@staticmethod
def _remap_legacy_keys(state_dict, prefix, *args, **kwargs):
    old_key = prefix + "linear.weight"
    new_key = prefix + "weight"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)

Collaborator

This should be done in the conversion_mapping with our WeightRenaming API!

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

Collaborator

@ArthurZucker ArthurZucker left a comment

better!

mapping["gemma4"] = [
    WeightRenaming(r"\.linear\.weight", ".weight"),
]
Collaborator

Valid if all layers use this! I didn't check, but we might need a restriction on the layer path?

Author

No other module in the Gemma3/3n/4 tree uses self.linear as an attribute name, so this should be safe.
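
A quick way to audit that claim on a loaded model (a sketch; it assumes the composition-based modules expose the inner layer as `.linear`):

import torch.nn as nn

# list every module path that exposes an inner nn.Linear named "linear"
hits = [
    name
    for name, mod in model.named_modules()
    if isinstance(getattr(mod, "linear", None), nn.Linear)
]
print(hits)  # should only contain ClippableLinear-style modules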

@Cyrilvallez
Member

Unfortunately, we cannot do that, as it fully breaks quantization! Quantization methods replace nn.Linear modules with their own, and here they would replace these layers, but the custom forward with the added clipping would then be fully lost!
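
To make the failure mode concrete, here is a generic module-swap loop of the kind quantization backends use (a sketch, not any specific quantizer's code; quant_linear_cls stands in for a quantized linear class):

import torch.nn as nn

def naive_quantize_swap(model, quant_linear_cls):
    # typical pattern: replace every nn.Linear with a quantized equivalent
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            # once Gemma4ClippableLinear subclasses nn.Linear, it matches here
            # too, and the swap silently drops its clipping forward
            setattr(model, name, quant_linear_cls(child.in_features, child.out_features))
        else:
            naive_quantize_swap(child, quant_linear_cls)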

@Cyrilvallez
Member

If you want to use peft, you need to explicitly set which modules you want to target manually. I think we had internal discussions about making that easier, cc @merveenoyan @BenjaminBossan
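
With the current composition design, the inner nn.Linear is already reachable at the ".linear" suffix, so it can be targeted by name. A minimal sketch (the regex and LoRA hyperparameters are illustrative, not verified against the actual Gemma4 module paths):

from peft import LoraConfig, get_peft_model

# when target_modules is a string, PEFT treats it as a regex and
# full-matches it against module names, so the inner nn.Linear
# submodules of the vision tower can be selected directly
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=r".*vision_tower.*\.linear",
)
peft_model = get_peft_model(model, config)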

@BenjaminBossan
Member

Exactly as Cyril said, it's a matter of setting the correct target modules. Changing the parent class is not the solution.
