feat: Neuron initialization and rescaling from Variance Transfer #237
TheoRudkiewicz wants to merge 20 commits into growingnet:main from …
Conversation
Codecov Report: ❌ Patch coverage is …
Pull request overview
Adds variance-transfer (VT) weight rescaling and (V,V)/(Z,-Z) neuron pairing to the growing-module extension workflow, with a comprehensive test suite validating correctness and edge cases.
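For readers new to the pairing scheme: my reading of "(V,V)/(Z,-Z)" (the thread does not spell it out) is that each new neuron is duplicated, keeping its outgoing direction V for both copies and negating the incoming direction Z of one copy, so that for an odd activation the pair cancels and growth is function-preserving at initialization. A minimal sketch under that assumption (the helper name is hypothetical):

```python
import torch

def pair_vv_z_negz(w_in: torch.Tensor, w_out: torch.Tensor):
    """Duplicate k new neurons into 2k paired ones: inputs (Z, -Z), outputs (V, V)."""
    w_in_paired = torch.cat([w_in, -w_in], dim=0)    # rows are neurons: (Z, -Z)
    w_out_paired = torch.cat([w_out, w_out], dim=1)  # columns are neurons: (V, V)
    return w_in_paired, w_out_paired

# k = 3 new neurons reading 8 inputs and feeding 5 outputs -> 6 paired neurons.
z, v = torch.randn(3, 8), torch.randn(5, 3)
z2, v2 = pair_vv_z_negz(z, v)
x = torch.randn(8)
# With an odd activation (tanh), the paired contributions cancel exactly:
assert torch.allclose(v2 @ torch.tanh(z2 @ x), torch.zeros(5), atol=1e-6)
```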
Changes:
- Implemented VT rescaling strategies and neuron-pairing utilities in `GrowingModule`, and integrated them into `create_layer_extensions`.
- Exposed rescaling/pairing controls at the `GrowingBlock` container level for block-wide usage and FOGRO-style workflows.
- Added an extensive new test module covering smoke, semantic, variance, BatchNorm, edge-case, and standalone method-usage tests.
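To make the new surface concrete, here is a hedged usage sketch. The names `create_layer_extensions`, `apply_rescaling`, `apply_neuron_pairing`, `"default_vt"`, `"vt_constraint_old_shape"`, and `neuron_pairing="vv_z_negz"` appear in this PR, but the exact signatures, keyword names, and the `GrowingBlock` constructor arguments below are assumptions, not verified against the diff:

```python
# Hypothetical sketch of the API surface added by this PR; signatures are assumed.
from gromo.containers.growing_block import GrowingBlock

block = GrowingBlock(in_features=64, out_features=64)  # constructor args assumed

# Compute layer extensions with VT rescaling (Strategy A) and (V,V)/(Z,-Z) pairing.
block.create_layer_extensions(
    rescaling="default_vt",      # beta = sqrt(fan_in_old / fan_in_new), alpha = 1
    neuron_pairing="vv_z_negz",  # doubles the extension by pairing each new neuron
)

# The same controls are also exposed as standalone delegation helpers:
block.apply_rescaling(rescaling="vt_constraint_old_shape")
block.apply_neuron_pairing(neuron_pairing="vv_z_negz")
```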
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/gromo/modules/growing_module.py` | Adds `apply_rescaling`, `apply_neuron_pairing`, and integrates rescaling/pairing into extension creation. |
| `src/gromo/containers/growing_block.py` | Threads new rescaling/neuron-pairing options through the block API and adds delegation helpers. |
| `tests/test_variance_transfer.py` | New tests validating VT rescaling + pairing behavior, including BatchNorm running-stat scaling and edge cases. |
```python
self.apply_neuron_pairing(neuron_pairing=neuron_pairing)
```
`create_layer_extensions(..., neuron_pairing="vv_z_negz")` doubles `extended_output_layer` / `extended_input_layer` sizes, but nothing updates `extension_size` bookkeeping (e.g., `eigenvalues_extension` or a dedicated field). Downstream, `apply_change()` uses `extension_size` (or infers it from `eigenvalues_extension`) to grow sized post-layer functions (e.g., `GrowingBatchNorm` / `GrowableIdentity`). If pairing is enabled and callers keep passing the pre-pairing size (or pass `None`), the model can end up with mismatched channel counts in post-layer functions.
Suggested change:

```python
self.apply_neuron_pairing(neuron_pairing=neuron_pairing)
self._sync_extension_size_bookkeeping()
```

```python
def _sync_extension_size_bookkeeping(self) -> None:
    """
    Synchronize extension-size bookkeeping with the current extension layers.

    Neuron pairing can change the effective number of added neurons by
    modifying the extension layers after they have been created. Downstream
    code may rely on bookkeeping fields such as ``extension_size`` or infer
    the size from ``eigenvalues_extension``, so we keep them aligned with the
    actual extension-layer shape here.
    """

    def _layer_extension_size(layer: torch.nn.Module | None) -> int | None:
        # Prefer explicit layer attributes, then fall back to the weight shape.
        if layer is None:
            return None
        if hasattr(layer, "out_features"):
            return int(layer.out_features)
        if hasattr(layer, "out_channels"):
            return int(layer.out_channels)
        weight = getattr(layer, "weight", None)
        if isinstance(weight, torch.Tensor) and weight.ndim > 0:
            return int(weight.shape[0])
        return None

    extension_size = _layer_extension_size(getattr(self, "extended_input_layer", None))
    if extension_size is None:
        previous_module = getattr(self, "previous_module", None)
        extension_size = _layer_extension_size(
            getattr(previous_module, "extended_output_layer", None)
        )
    if extension_size is None:
        return
    self.extension_size = extension_size

    eigenvalues_extension = getattr(self, "eigenvalues_extension", None)
    if isinstance(eigenvalues_extension, torch.Tensor):
        current_size = int(eigenvalues_extension.numel())
        if current_size == 0 or current_size == extension_size:
            return
        # Pairing duplicates neurons, so repeat the eigenvalues to match.
        if extension_size % current_size == 0:
            repeat_factor = extension_size // current_size
            self.eigenvalues_extension = eigenvalues_extension.repeat(repeat_factor)
```
I think it's not a problem, since for now there is no plan to use pairing together with the eigenvalues extension.
stephane-rivaud left a comment:
I did not preview the docs changes nor did I thoroughly investigate the test cases, but the functional changes look good. I approve this PR.
Before merging, I need to check that the initialisation scale is good even with pairing.
In …
The bottom line is that we actually consider a tensor of size …
There was indeed a problem; I should have fixed it in d2d0d41129562d8e45b330377d01eaa79a3f5fa1.
Yes, I chose to ask for growth of …
Remark from @alexdavey: we should probably include the "gain" in the rescaling (even if variance transfer does not do it?).
alexdavey left a comment:
Thanks @TheoRudkiewicz, here are a few comments :)
| * ``"default_vt"`` (Strategy A): beta = sqrt(fan_in_old / fan_in_new), | ||
| alpha = 1 (the previous layer input is not extended). | ||
| * ``"vt_constraint_old_shape"`` (Strategy B): alpha and beta chosen so |
Why would we want to enforce variance with respect to the previous fan_in? If we consider a sequence of growth steps, the behaviour we get is:
- `default_vt`: Variance ∝ 1/current_fan_in (with no training).
- `vt_constraint_new_shape`: Variance ∝ 1/current_fan_in.
- `vt_constraint_old_shape`: Variance ∝ 1/penultimate_fan_in.

The behaviour of the last option, `vt_constraint_old_shape`, does not really make sense to me.
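For intuition (my derivation from the Strategy A factor quoted above, not part of the thread): applying beta_t = sqrt(fan_in_{t-1} / fan_in_t) at every growth step telescopes, which is why the first two strategies track the current fan-in:

```latex
% Variance after t growth steps under default_vt, assuming Var(w_0) = 1/fan_in_0
% and a rescaling by beta_s^2 = fan_in_{s-1} / fan_in_s at each step s:
\operatorname{Var}(w_t)
  = \frac{1}{\mathrm{fan\_in}_0}
    \prod_{s=1}^{t} \frac{\mathrm{fan\_in}_{s-1}}{\mathrm{fan\_in}_s}
  = \frac{1}{\mathrm{fan\_in}_t}
```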
Also: should we consider making the gain a parameter in the constraint strategies, so that we can use e.g. Kaiming's sqrt(2)/fan_in instead?
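One possible shape for that gain parameter (hypothetical signature; `gain = sqrt(2)` reproduces Kaiming/He scaling for ReLU, where Var(w) = gain² / fan_in):

```python
import math
import torch

def vt_rescale_with_gain(layer: torch.nn.Linear, fan_in_new: int, gain: float = 1.0) -> None:
    """Hypothetical constraint-style rescaling targeting Var(w) = gain**2 / fan_in_new."""
    target_std = gain / math.sqrt(fan_in_new)
    with torch.no_grad():
        layer.weight.mul_(target_std / layer.weight.std().item())

# gain=1 keeps the plain VT target; gain=sqrt(2) gives Kaiming/He scaling for ReLU.
layer = torch.nn.Linear(128, 64)
vt_rescale_with_gain(layer, fan_in_new=128, gain=math.sqrt(2))
```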
> Why?

Short answer: it's the one proposed by the VT paper.
Long answer: https://theorudkiewicz.github.io/gromo/tech_notes/variance_transfer.html#part-3-combined-analysis
About the gain parameter: probably.
In the end we add …
I leave this for future work, as it should be integrated in many places like …
See https://theorudkiewicz.github.io/gromo/tech_notes.html for updated tech notes.