Close form optimisation for block linear -> activation -> linear #231

Edarfix wants to merge 20 commits into growingnet:main from …
Conversation
Codecov Report: ❌ Patch coverage is …
Pull request overview
Adds a new closed-form optimizer for Linear -> activation -> Linear blocks (alternating closed-form updates for the two linear layers), along with a dedicated unit test suite, and updates TensorStatistic typing / naming to remove a Pylance warning.
Changes:
- Add src/gromo/utils/optimisation.py implementing the closed-form block optimizer, statistics collection, and helper solvers.
- Add tests/test_optimisation.py covering supported block layouts, solver branches, early-stopping paths, and end-to-end teacher/student improvement.
- Fix naming/typing around TensorStatisticWithEstimationError and broaden the accepted update function signature.
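To illustrate the idea behind alternating closed-form updates (this is a minimal sketch, not the PR's implementation in optimisation.py): with the first linear layer held fixed, the second layer of a Linear -> activation -> Linear block admits an exact least-squares solution over the hidden activations.

```python
# Sketch of one half of the alternation: closed-form update of the second
# linear layer while the first is frozen. All shapes and names here are
# illustrative assumptions, not the PR's actual API.
import torch

torch.manual_seed(0)
x = torch.randn(128, 5)   # inputs
y = torch.randn(128, 3)   # regression targets
w1 = torch.randn(8, 5)    # first linear layer, held fixed for this step

h = torch.relu(x @ w1.T)  # hidden activations after the nonlinearity

# Closed-form least squares: W2 = argmin_W || H W^T - Y ||^2
w2 = torch.linalg.lstsq(h, y).solution.T  # shape (3, 8)

residual = ((h @ w2.T - y) ** 2).mean()
```

The alternation then freezes this second layer and updates the first one, repeating until the loss stops improving; the exact scheme for the first layer is what the new module implements.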
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
Summary per file:
| File | Description |
|---|---|
| tests/test_tensor_statistics.py | Renames the estimation-error statistic class usage to the corrected spelling. |
| tests/test_optimisation.py | New tests for the closed-form optimizer module and its helper functions. |
| src/gromo/utils/tensor_statistic.py | Adds type aliases for update functions and renames TensorStatisticWithEstimationError. |
| src/gromo/utils/optimisation.py | New closed-form optimizer implementation for 2-layer linear blocks. |
| docs/source/whats_new.rst | Adds an entry (currently contains unresolved merge-conflict markers). |
| docs/source/sphinxext/gh_substitutions.py | Adjusts gh_role definition (currently introduces mutable-default risk). |
| docs/source/conf.py | Minor doc config formatting / noqa cleanup. |
```python
update_function=lambda batch_tensor, batch_weight: (
    (batch_tensor * batch_weight.unsqueeze(1)).sum(dim=0),
    float(batch_weight.sum().item()),
),
```
These weighted-statistic update functions return float(batch_weight.sum().item()) as the sample count/mass. However, TensorStatistic and TensorStatisticWithEstimationError currently type and store samples as an int and expect nb_sample to be an integer count. Either change the statistics helpers to support a float-valued sample mass (and type it accordingly), or return an int count and track the weight mass separately; otherwise the typing (Pyright/Pylance) and the semantics of samples become inconsistent.
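A minimal sketch of the first option the comment suggests (a float-valued sample mass). The class and attribute names here are hypothetical stand-ins, not the actual TensorStatistic API:

```python
# Hypothetical accumulator tracking a float-valued sample mass instead of
# an int count, as the review comment suggests. Illustrative only.
import torch


class WeightedMeanStatistic:
    def __init__(self) -> None:
        self.total: torch.Tensor | None = None  # running weighted sum
        self.mass: float = 0.0                  # float mass, not an int count

    def update(self, batch_tensor: torch.Tensor, batch_weight: torch.Tensor) -> None:
        # Same shape convention as the quoted update_function above.
        weighted = (batch_tensor * batch_weight.unsqueeze(1)).sum(dim=0)
        self.total = weighted if self.total is None else self.total + weighted
        self.mass += float(batch_weight.sum().item())

    def mean(self) -> torch.Tensor:
        assert self.total is not None, "update() must be called first"
        return self.total / self.mass


stat = WeightedMeanStatistic()
stat.update(torch.ones(4, 2), torch.full((4,), 0.5))
```

With this normalization, the accumulated mean is correct even when per-sample weights do not sum to the batch size.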
```python
StatisticUpdateResult: TypeAlias = tuple[torch.Tensor, int]
StatisticUpdateFunction: TypeAlias = Callable[..., StatisticUpdateResult]
```
StatisticUpdateResult fixes the returned sample count to int, but the new closed-form optimizer uses TensorStatisticWithEstimationError to accumulate weighted moments where the natural normalization is a float-valued sample mass. To avoid type-checker errors and clarify semantics, consider changing this alias (and TensorStatistic.samples) to accept float (or float | int).
@copilot apply changes based on this feedback
```python
def optimize(
    self,
    block: nn.Module | TwoLayerLinearBlockView,
    dataloader: list[Any] | Any,
) -> OptimizationResult:
```
ClosedFormBlockOptimizer.optimize iterates over dataloader multiple times (once per alternating step, plus a final pass). If callers pass an iterator/generator (single-pass), subsequent iterations will silently see no batches and raise/produce incorrect updates. Consider requiring a re-iterable Iterable/DataLoader (and document it), or materialize/copy batches when a one-shot iterator is provided.
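One possible guard, sketched below; the helper name is hypothetical and the `Iterator` check is just one heuristic for detecting single-pass inputs (a list or a typical DataLoader passes through unchanged):

```python
# Sketch of the suggested guard: materialize a one-shot iterator so the
# optimizer's repeated passes all see the batches. Illustrative only.
from collections.abc import Iterable, Iterator


def ensure_reiterable(dataloader: Iterable) -> Iterable:
    if isinstance(dataloader, Iterator):  # generators/iterators are single-pass
        return list(dataloader)           # copy the batches once
    return dataloader                     # lists and DataLoaders re-iterate fine


batches = ensure_reiterable(b for b in range(3))
first_pass = list(batches)
second_pass = list(batches)  # non-empty thanks to the copy
```

Alternatively, the docstring could simply require a re-iterable input and raise on a second empty pass.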
@copilot apply changes based on this feedback
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
sylvchev
left a comment
An important PR; we could iterate in the next push to update the code if needed, but it seems mature enough. On top of the tests, it could be interesting to add an example demonstrating how to use the code.
Add optimisation.py and its corresponding tests.
This PR introduces closed-form optimization utilities for Linear -> activation -> Linear blocks, along with a dedicated test suite covering the main supported layouts and optimization paths.
It also includes a small typing update in TensorStatistic to suppress a Pylance error by broadening the accepted update_function callable signature.