Make Gemma4ClippableLinear inherit from nn.Linear for PEFT/LoRA compatibility#45388
albertorkive wants to merge 2 commits into huggingface:main
Conversation
…tibility Gemma4ClippableLinear previously subclassed nn.Module and wrapped an internal nn.Linear via composition. This prevented PEFT/LoRA from discovering these layers since it uses isinstance(module, nn.Linear). Change ClippableLinear to inherit from nn.Linear directly, preserving the optional input/output clamping behavior. Add a state dict pre-hook to remap legacy "linear.weight" keys from existing checkpoints to the new "weight" key for backward compatibility. Also update the weight converter and fix three .linear.weight references in forward methods.
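For orientation, here is a minimal sketch of what such a subclass could look like. The constructor signature, the clamp parameter names, and the bounds are illustrative assumptions, not taken from the actual Gemma4 implementation; only the overall pattern (an `nn.Linear` with optional input/output clamping) follows the commit message above.

```python
import torch
from torch import nn

class ClippableLinearSketch(nn.Linear):
    """Illustrative only: a Linear that optionally clamps its input and output."""

    def __init__(self, in_features, out_features, bias=True,
                 input_clip=None, output_clip=None):
        super().__init__(in_features, out_features, bias=bias)
        self.input_clip = input_clip    # e.g. (-10.0, 10.0), or None to disable
        self.output_clip = output_clip

    def forward(self, x):
        if self.input_clip is not None:
            x = x.clamp(*self.input_clip)
        out = super().forward(x)
        if self.output_clip is not None:
            out = out.clamp(*self.output_clip)
        return out

# Because it IS an nn.Linear, isinstance-based discovery (e.g. PEFT/LoRA) now sees it.
assert isinstance(ClippableLinearSketch(8, 8), nn.Linear)
```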
ArthurZucker left a comment
SGTM. I don't know why this was not done earlier (cc @Cyrilvallez, as you worked on this, there might have been a reason?)
```python
@staticmethod
def _remap_legacy_keys(state_dict, prefix, *args, **kwargs):
    old_key = prefix + "linear.weight"
    new_key = prefix + "weight"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)
```
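For context on how a pre-hook like this takes effect, below is a small, self-contained sketch of the pattern. The wiring uses PyTorch's `_register_load_state_dict_pre_hook`; whether the PR registers the hook exactly this way is not visible in the excerpt, and the toy class is illustrative rather than the real Gemma4 module.

```python
import torch
from torch import nn

class LegacyCompatLinear(nn.Linear):
    """Toy stand-in: remaps old 'linear.weight' checkpoint keys to 'weight' on load."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)
        # The pre-hook runs before load_state_dict consumes this module's keys,
        # so legacy checkpoints still load without missing/unexpected keys.
        self._register_load_state_dict_pre_hook(self._remap_legacy_keys)

    @staticmethod
    def _remap_legacy_keys(state_dict, prefix, *args, **kwargs):
        old_key = prefix + "linear.weight"
        if old_key in state_dict:
            state_dict[prefix + "weight"] = state_dict.pop(old_key)

layer = LegacyCompatLinear(4, 4)
legacy_ckpt = {"linear.weight": torch.randn(4, 4)}  # old composition-era key
layer.load_state_dict(legacy_ckpt, strict=True)      # loads cleanly after the remap
```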
This should be done in the `conversion_mapping` with our `WeightRenaming` API!
[For maintainers] Suggested jobs to run (before merge): run-slow: gemma4
```python
mapping["gemma4"] = [
    WeightRenaming(r"\.linear\.weight", ".weight"),
]
```
Valid if all layers use this! I did not check, but we might need a restriction on the layer path?
No other module in the Gemma3/3n/4 tree uses `self.linear` as an attribute name. Should be safe.
Unfortunately, we cannot do that as it fully breaks quantization! Quantization methods replace `nn.Linear` layers with their own modules.
If you want to use LoRA on these layers, you can already do so by targeting them explicitly via `target_modules`.
Exactly as Cyril said, it's a matter of setting the correct target modules. Changing the parent class is not the solution.
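For reference, the route the reviewers point to is to aim PEFT at the wrapped inner `nn.Linear` modules by name rather than changing the class hierarchy. In the sketch below, `model` is assumed to be an already-loaded Gemma4 model and the projection paths are assumed to end in `.linear`; neither is verified against the actual code.

```python
from peft import LoraConfig, get_peft_model

# Hedged sketch: `model` is assumed to be a loaded Gemma4 model instance, and the
# `.linear` suffix assumes the composition layout described in this thread.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    # A string is treated as a regex over full module names, so this targets the
    # inner nn.Linear inside each attention projection wrapper.
    target_modules=r".*(q_proj|k_proj|v_proj|o_proj)\.linear",
)
peft_model = get_peft_model(model, lora_config)
```

Because the matched modules are plain `nn.Linear` instances, PEFT's layer-type dispatch works without any change to the model classes.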
**What does this PR do?**

Makes `Gemma4ClippableLinear` inherit from `nn.Linear` instead of wrapping one via composition, enabling PEFT/LoRA to discover and target vision/audio encoder layers.

**Problem:** PEFT's LoRA module discovery uses `isinstance(module, nn.Linear)` to find targetable layers. The current `Gemma4ClippableLinear` subclasses `nn.Module` and stores an internal `self.linear = nn.Linear(...)`, so PEFT skips all vision and audio encoder projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`, ffw layers). Users cannot fine-tune the Gemma4 vision tower with LoRA.

**Fix:**
- Change `Gemma4ClippableLinear` to inherit from `nn.Linear` directly
- Use `self.weight` (standard `nn.Linear`) instead of `self.linear.weight`
- Add a `_remap_legacy_keys` state dict pre-hook for backward compatibility with existing checkpoints that store the weight under `"linear.weight"`
- Fix the three `.linear.weight` references in forward methods

**Backward compatibility:** Existing checkpoints with `linear.weight` keys load correctly via the pre-hook remap. Verified with `strict=True`.
**Files changed**
- `src/transformers/models/gemma4/modular_gemma4.py`: source of truth
- `src/transformers/models/gemma4/modeling_gemma4.py`: generated, matching changes
- `src/transformers/models/gemma4/convert_gemma4_weights.py`: updated key names

**How to reproduce the bug**
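As a self-contained illustration of the failure mode, the toy wrapper below mimics the composition pattern described above; it is not the actual Gemma4 code, and the layer sizes are arbitrary.

```python
import torch
from torch import nn

# Toy stand-in for the pre-PR composition pattern (not the actual Gemma4 class).
class WrappedLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

layer = WrappedLinear(8, 8)
# PEFT's discovery relies on an isinstance check, so a composed wrapper is skipped.
print(isinstance(layer, nn.Linear))         # False -> not targetable as-is
print(isinstance(layer.linear, nn.Linear))  # True  -> only the inner module qualifies
```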
After this PR, all `Gemma4ClippableLinear` modules pass `isinstance(mod, nn.Linear)`.
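A quick sanity check of that claim could look like the following; `model` is assumed to be an instantiated Gemma4 model, which this excerpt does not construct.

```python
from torch import nn

# Hedged sketch: `model` is assumed to be a loaded Gemma4 model instance.
clippable = [m for m in model.modules() if type(m).__name__ == "Gemma4ClippableLinear"]
assert clippable and all(isinstance(m, nn.Linear) for m in clippable)
```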