Skip to content

Close form optimisation for block linear -> activation -> linear#231

Open
Edarfix wants to merge 20 commits intogrowingnet:mainfrom
Edarfix:layer-initialization
Open

Close form optimisation for block linear -> activation -> linear#231
Edarfix wants to merge 20 commits intogrowingnet:mainfrom
Edarfix:layer-initialization

Conversation

@Edarfix
Copy link
Copy Markdown
Collaborator

@Edarfix Edarfix commented Mar 24, 2026

Add optimisation.py and its corresponding tests.

This PR introduces closed-form optimization utilities for Linear -> activation -> Linear blocks, along with a dedicated test suite covering the main supported layouts and optimization paths.

It also includes a small typing update in TensorStatistic to suppress a Pylance error by broadening the accepted update_function callable signature.

@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 31, 2026

Codecov Report

❌ Patch coverage is 97.26027% with 8 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/gromo/utils/optimisation.py 97.19% 6 Missing and 2 partials ⚠️
Flag Coverage Δ
unittests 94.54% <97.26%> (+0.16%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
src/gromo/utils/tensor_statistic.py 100.00% <100.00%> (ø)
src/gromo/utils/optimisation.py 97.19% <97.19%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new closed-form optimizer for Linear -> activation -> Linear blocks (alternating closed-form updates for the two linear layers), along with a dedicated unit test suite, and updates TensorStatistic typing / naming to remove a Pylance warning.

Changes:

  • Add src/gromo/utils/optimisation.py implementing the closed-form block optimizer, statistics collection, and helper solvers.
  • Add tests/test_optimisation.py covering supported block layouts, solver branches, early-stopping paths, and end-to-end teacher/student improvement.
  • Fix naming/typing around TensorStatisticWithEstimationError and broaden the accepted update function signature.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
tests/test_tensor_statistics.py Renames the estimation-error statistic class usage to the corrected spelling.
tests/test_optimisation.py New tests for the closed-form optimizer module and its helper functions.
src/gromo/utils/tensor_statistic.py Adds type aliases for update functions and renames TensorStatisticWithEstimationError.
src/gromo/utils/optimisation.py New closed-form optimizer implementation for 2-layer linear blocks.
docs/source/whats_new.rst Adds an entry (currently contains unresolved merge-conflict markers).
docs/source/sphinxext/gh_substitutions.py Adjusts gh_role definition (currently introduces mutable-default risk).
docs/source/conf.py Minor doc config formatting / noqa cleanup.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread docs/source/whats_new.rst Outdated
Comment thread docs/source/sphinxext/gh_substitutions.py Outdated
Comment thread src/gromo/utils/optimisation.py Outdated
Comment on lines +761 to +764
update_function=lambda batch_tensor, batch_weight: (
(batch_tensor * batch_weight.unsqueeze(1)).sum(dim=0),
float(batch_weight.sum().item()),
),
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These weighted-statistic update functions return float(batch_weight.sum().item()) as the sample count/mass. However TensorStatistic/TensorStatisticWithEstimationError currently types and stores samples as an int and expects nb_sample to be an integer count. Either change the statistics helpers to support a float-valued sample mass (and type it accordingly) or return an int count and track weight mass separately; otherwise typing (Pyright/Pylance) and the semantics of samples become inconsistent.

Copilot uses AI. Check for mistakes.
Comment on lines +9 to +10
StatisticUpdateResult: TypeAlias = tuple[torch.Tensor, int]
StatisticUpdateFunction: TypeAlias = Callable[..., StatisticUpdateResult]
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StatisticUpdateResult fixes the returned sample count to int, but the new closed-form optimizer uses TensorStatisticWithEstimationError to accumulate weighted moments where the natural normalization is a float-valued sample mass. To avoid type-checker errors and clarify semantics, consider changing this alias (and TensorStatistic.samples) to accept float (or float | int).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Comment on lines +903 to +907
def optimize(
self,
block: nn.Module | TwoLayerLinearBlockView,
dataloader: list[Any] | Any,
) -> OptimizationResult:
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ClosedFormBlockOptimizer.optimize iterates over dataloader multiple times (once per alternating step, plus a final pass). If callers pass an iterator/generator (single-pass), subsequent iterations will silently see no batches and raise/produce incorrect updates. Consider requiring a re-iterable Iterable/DataLoader (and document it), or materialize/copy batches when a one-shot iterator is provided.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Edarfix and others added 3 commits March 31, 2026 17:07
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Copy link
Copy Markdown
Collaborator

@sylvchev sylvchev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An important PR, we could iterate in the next push to update the code if needed, but it seems mature enough. On top of the tests, it could bé interesting to add an example demonstrating how to use the code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants