60 changes: 38 additions & 22 deletions README.md
@@ -4,7 +4,7 @@

HeavyBall is an optimizer library for PyTorch where every optimizer is assembled from composable, compiled building
blocks. It includes API-compatible replacements for `torch.optim.AdamW`, `SGD`, and `RMSprop`, alongside Muon, SOAP (
Shampoo), PSGD (Kronecker), ADOPT, Schedule-Free, LaProp, and others.
Shampoo), PSGD (Kronecker), LATHER, ADOPT, Schedule-Free, LaProp, and others.

The building blocks, over 100 functions in [`utils.py`](heavyball/utils.py), are each compiled with
`torch.compile(fullgraph=True)` and fuse into Triton kernels. Features like MARS gradient correction,
@@ -21,21 +21,31 @@ Requires PyTorch >= 2.2.

```python
from heavyball import AdamW

opt = AdamW(model.parameters(), lr=1e-3)
```

```python
from heavyball import SOAP # Shampoo-based preconditioning

opt = SOAP(model.parameters(), lr=3e-3)
```

```python
from heavyball import LATHER # Lie-group Adam Through Harmonic Eigenbasis Rotations

opt = LATHER(model.parameters(), lr=1e-3)
```

```python
from heavyball import Muon

opt = Muon(model.parameters(), lr=0.02, ecc="bf16+8", mars=True, caution=True)
```

```python
from heavyball import SplitOpt, Muon, AdamW

opt = SplitOpt([
{'params': matrices, 'optimizer': Muon, 'lr': 0.02},
{'params': vectors, 'optimizer': AdamW, 'lr': 1e-3},
@@ -44,6 +54,8 @@ opt = SplitOpt([

The API matches `torch.optim`, with the same parameter groups and the same `step()`/`zero_grad()` interface.
See [`examples/`](examples/) for training scripts.
By default, HeavyBall consumes gradients during `step()` and clears `p.grad` once it has used it. Pass
`consume_grad=False` if your training loop needs gradients to remain attached after the optimizer step.
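For instance, to keep gradients attached for post-step inspection (a configuration fragment; `model` is a placeholder):

```python
from heavyball import AdamW

# consume_grad=False leaves p.grad in place after step(),
# e.g. for gradient-norm logging between optimizer steps
opt = AdamW(model.parameters(), lr=1e-3, consume_grad=False)
```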

## Optimizers

@@ -55,29 +67,25 @@ training, and SAM.
<summary>Full list</summary>

**First-order:**
AdamW, NAdam, RMSprop, ADOPT, ForeachAdEMAMix, LaProp, SignLaProp, SGD, Scion, UnscaledAdamW, ForeachAdamC, SUDSAdamW
AdamW, NAdam, RMSprop, ADOPT, AdEMAMix, LaProp, SignLaProp, SGD, Scion, UnscaledAdamW, AdamC, SUDSAdamW

**Schedule-Free:**
SFAdamW, PaLMSFAdamW
SFAdamW

Schedule-Free optimizers override `.eval()` and `.train()` to swap between training and evaluation parameter states.
Call `opt.eval()` before validation and `opt.train()` before resuming training.
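A minimal training-loop sketch of this pattern, assuming `model`, `train_loader`, `compute_loss`, and a `validate` helper exist:

```python
from heavyball import SFAdamW

opt = SFAdamW(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    opt.train()  # swap in the training parameter state
    for batch in train_loader:
        loss = compute_loss(model, batch)
        loss.backward()
        opt.step()
    opt.eval()   # swap in the evaluation parameter state
    validate(model)
```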

**Orthogonal:**
Muon, MuonLaProp, OrthoLaProp, LaPropOrtho
Muon, MuonAdamW, MuonLaProp, HyperBallAdamW, OrthoLaProp, LaPropOrtho

**Shampoo-based (SOAP):**
SOAP, PaLMSOAP, PrecondScheduleSOAP, PrecondSchedulePaLMSOAP, SOAPNAdam, SOAPAdEMAMix, ForeachSOLP
SOAP, SOAPNAdam, SOAPAdEMAMix, SOLP

**PSGD (Kronecker):**
PSGDKron, CachedPSGDKron, DelayedPSGD, CachedDelayedPSGDKron, PurePSGD, NewtonPSGDKron, NewtonHybrid2PSGDKron

`Newton`-PSGD requires a closure passed to `step()`.
PSGDKron, LATHER, PSGDPRO

**PSGD (Low-Rank):**
PSGDLRA, DelayedPSGDLRA, NewtonPSGDLRA, NewtonHybrid2PSGDLRA

`Newton`-PSGD requires a closure passed to `step()`.
PSGDLRA

**SAM:**
SAMWrapper, MSAMLaProp
@@ -132,9 +140,10 @@ Available modes: `bf16+8`, `bf16+16`, `fp16+8`, `fp16+16`.

HeavyBall works with both DDP and FSDP. First-order optimizers are elementwise and operate directly on FSDP shards with
no repartitioning. Second-order methods (Muon, SOAP, PSGD) need the full parameter to compute their update, so HeavyBall
auto-detects FSDP-sharded parameters on the first step and repartitions them: each weight matrix is assigned to one rank
in round-robin, which reconstructs the full parameter, computes the update, and broadcasts the result. This saves both
compute and memory compared to DDP-style redundant updates, at the cost of communication.
auto-detects FSDP-sharded parameters on the first step and repartitions them with a metadata-first `all_to_all_single`
exchange: each weight matrix is deterministically assigned to one rank, shard metadata is exchanged up front, the owner
reconstructs the full parameter, computes the update once, and returns the updated shards. This saves both compute and
memory compared to DDP-style redundant updates, at the cost of communication.
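The ownership step can be illustrated with a standalone sketch (a hypothetical `assign_owners` helper, not HeavyBall's internal code): because the assignment is a pure function of parameter order and world size, every rank computes the same mapping without communicating.

```python
def assign_owners(num_params: int, world_size: int) -> dict[int, int]:
    """Round-robin: parameter i is reconstructed and updated by rank i % world_size."""
    return {i: i % world_size for i in range(num_params)}

# 5 weight matrices across 2 ranks: each rank owns a disjoint subset
owners = assign_owners(5, world_size=2)
# owners == {0: 0, 1: 1, 2: 0, 3: 1, 4: 0}
```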

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -157,23 +166,25 @@ opt = SOAP(model.parameters(), lr=3e-3, orig_shapes=shapes)
## Building Custom Optimizers

Every built-in optimizer is a chain of `FunctionTransform`s; the same API is available for building custom optimizers.
`Branch` runs parallel transform paths with a merge function, which is useful for grafted optimizers or ensemble
`Parallel` runs parallel transform paths with a merge function, which is useful for grafted optimizers or ensemble
updates.

```python
import heavyball.chainable as C


def graft(outputs, eps=1e-8):
adam_update, sgd_update = outputs
return [s * (a.norm() / s.norm().add(eps)) for a, s in zip(adam_update, sgd_update)]


class GraftedAdam(C.BaseOpt):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, warmup_steps=0, foreach=True):
weight_decay=0, warmup_steps=0, multi_tensor=True):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
warmup_steps=warmup_steps)
branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft)
super().__init__(params, defaults, foreach, fns=(branch,))
branch = C.Parallel(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft)
super().__init__(params, defaults, multi_tensor, fns=(branch,))
```
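To see what the `graft` merge computes, here is the same rescaling applied to two plain tensors: the result points along the SGD update but carries the Adam update's magnitude.

```python
import torch

eps = 1e-8
adam_update = torch.tensor([3.0, 4.0])  # norm 5.0 -> supplies the magnitude
sgd_update = torch.tensor([0.0, 2.0])   # norm 2.0 -> supplies the direction

grafted = sgd_update * (adam_update.norm() / sgd_update.norm().add(eps))
# grafted is approximately [0.0, 5.0]: SGD's direction at Adam's norm
```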

Custom optimizers that inherit from `BaseOpt` get ECC, MARS, caution, clipping, warmup, and stochastic rounding
@@ -204,14 +215,19 @@ Custom optimizers built via the chainable API inherit this behavior.

## Benchmarks

HeavyBall includes a diagnostic benchmark suite via [LightBench](https://github.com/HomebrewML/LightBench) that tests
HeavyBall includes a benchmark suite via [LightBench](https://github.com/HomebrewML/LightBench) that tests
for silent optimizer failures across difficulty levels. Results and methodology are documented
in [docs/benchmark.md](docs/benchmark.md).

## Migrating from 1.x
[`benchmarks/bench_optimizer_step.py`](benchmarks/bench_optimizer_step.py) measures optimizer latency, with
AdamW step times dropping from 10.63 ms in HeavyBall 2 to 4.15 ms in HeavyBall 3.
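The benchmark script exposes its parameters through a Typer CLI; an assumed invocation (flag names derived from the script's function signature, and a CUDA device is required):

```shell
# benchmark HeavyBall's AdamW on the default 32x (2048, 2048) parameter set
python benchmarks/bench_optimizer_step.py --optimizer AdamW --library heavyball

# compare against torch.optim's fused AdamW on custom shapes
python benchmarks/bench_optimizer_step.py --library torch --fused \
    --shape 4096x4096 --shape 1024x1024
```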

## Migrating

**From 2.x:** See the [3.0.0 migration guide](docs/heavyball3.md) for renamed classes, removed kwargs, and checkpoint
conversion.

See the [2.0.0 migration notes](docs/heavyball2.md) for a full checklist, and `scripts/migrate_optimizer_state.py` for
checkpoint conversion.
**From 1.x:** See the [2.0.0 migration notes](docs/heavyball2.md), then follow the 3.0.0 guide.

## Contributing

Binary file modified assets/benchmark_matrix.png
85 changes: 85 additions & 0 deletions benchmarks/bench_optimizer_step.py
@@ -0,0 +1,85 @@
from enum import StrEnum
from math import prod
from time import perf_counter

import numpy as np
import torch
import typer

import heavyball

app = typer.Typer(add_completion=False, pretty_exceptions_enable=False)

DEFAULT_SHAPES = ((2048, 2048),) * 32


class DType(StrEnum):
float16 = "float16"
bfloat16 = "bfloat16"
float32 = "float32"


class Library(StrEnum):
heavyball = "heavyball"
torch = "torch"


def parse_shape(text: str) -> tuple[int, ...]:
try:
shape = tuple(map(int, text.lower().replace("x", " ").split()))
except ValueError as e:
raise typer.BadParameter(f"invalid shape: {text!r}") from e
if not shape:
raise typer.BadParameter(f"invalid shape: {text!r}")
return shape


@app.command()
def main(
optimizer: str = "AdamW",
library: Library = Library.heavyball,
dtype: DType = DType.float32,
shape: list[str] | None = None,
compile_step: bool = False,
fused: bool | None = None,
update_precond: bool | None = None,
steps: int = 300,
warmup: int = 20,
windows: int = 6,
seed: int = 0,
):
shapes = DEFAULT_SHAPES if shape is None else tuple(map(parse_shape, shape))
torch_dtype = getattr(torch, dtype)
kwargs = {"compile_step": compile_step} if library is Library.heavyball else {}
if fused is not None and library is Library.torch:
kwargs["fused"] = fused
if update_precond is not None and library is Library.heavyball:
kwargs["preconditioner_update_probability"] = float(update_precond)

gen = torch.Generator(device="cuda").manual_seed(seed)
params = []
for dims in shapes:
param = torch.nn.Parameter(torch.randn(dims, device="cuda", dtype=torch_dtype, generator=gen))
param.grad = torch.randn(dims, device="cuda", dtype=torch_dtype, generator=gen)
params.append(param)

module = heavyball if library is Library.heavyball else torch.optim
step = getattr(module, optimizer)(params, **kwargs).step
for _ in range(warmup):
step()

times = []
for _ in range(windows):
torch.cuda.synchronize()
start = perf_counter()
for _ in range(steps):
step()
torch.cuda.synchronize()
times.append((perf_counter() - start) / steps)

print(f"{len(shapes)} tensors, {sum(prod(s) for s in shapes)} total params")
print(f"Median Time: {np.median(times) * 1e6:.3f}µs")


if __name__ == "__main__":
app()
2 changes: 1 addition & 1 deletion benchmarks/bench_singular_values.py
@@ -113,7 +113,7 @@ def key_fn(r):
f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {min(rerrs):>10.6f} {max(rerrs):>10.6f} {errs:>6} {len(items):>5}"
)
else:
print(f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {'':>10} {'':>10} {errs:>6} {len(items):>5}")
print(f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {'-':>10} {'-':>10} {errs:>6} {len(items):>5}")


def main():
8 changes: 4 additions & 4 deletions docs/benchmark.md
@@ -56,9 +56,9 @@ reinforcing the need for diagnostic rather than purely comparative evaluation.
| Optimizer | Cautious¹ | Mars² | Success | Attempts | Avg Runtime (s) |
|:---------------|:----------|:------|:--------|:---------|:----------------|
| PSGDKron | No | No | 77.0% | 73.2 | 8240 |
| NewtonPSGDKron | No | No | 77.0% | 80.5 | 9052 |
| PSGDKron (Newton) | No | No | 77.0% | 80.5 | 9052 |
| AdamW | Yes | No | 75.7% | 61.2 | 8072 |
| ForeachSOAP | No | No | 72.5% | 77.9 | 7827 |
| SOAP | No | No | 72.5% | 77.9 | 7827 |
| AdamW | No | No | 72.3% | 107.8 | 10029 |
| MuonLaProp | No | No | 68.2% | 82.7 | 10141 |
| RMSprop | No | No | 55.6% | 114.4 | 10725 |
@@ -82,7 +82,7 @@ informed choice.
### Case Study: Escaping the Saddle Point

An optimizer’s inability to navigate a saddle point is a classic example of a silent failure. A key test of an
optimizer's robustness is its ability to navigate a saddle pointa region that is a minimum in one direction but a
optimizer's robustness is its ability to navigate a saddle point - a region that is a minimum in one direction but a
maximum in another. The gradient approaches zero at the center, trapping first-order methods that rely solely on the
gradient.

@@ -95,7 +95,7 @@ optimizer may be unreliable in these settings.
## Conclusion

The HeavyBall Benchmark represents a necessary shift in how we evaluate optimizers, moving from a culture of
score-chasing to one of deep, diagnostic understanding. These hidden failures aren’t rare edge casesthey’re a routine
score-chasing to one of deep, diagnostic understanding. These hidden failures aren’t rare edge cases - they’re a routine
source of wasted compute and disappointing models. By making them explicit, the benchmark equips researchers and
practitioners with a detailed map of an optimizer's capabilities. By clearly identifying hidden failure modes,
practitioners can confidently choose, tune, or reconsider their optimization strategies, ultimately leading to more
4 changes: 2 additions & 2 deletions docs/heavyball2.md
@@ -4,7 +4,7 @@

* First‑class SAM via `SAMWrapper` (closure‑based)
* More robust checkpoint/restore with HeavyBall‑internal state
* New optimizers: `SGD`, `ForeachAdamC`, `MSAMLaProp`
* New optimizers: `SGD`, `AdamC`, `MSAMLaProp`
* Overhauled chainable pipeline: indexed transforms, branching, internal gradient‑accumulation, and `SqueezeGrad`
* Faster, more accurate code paths
* New `heavyball.helpers` with Optuna‑compatible samplers and utilities
@@ -18,7 +18,7 @@
* `SAMWrapper` applies sharpness‑aware minimization to any HeavyBall optimizer while preserving the wrapped step logic;
requires a closure
* `SGD` built on the chainable internals
* `ForeachAdamC`, a ["corrected version of Adam"](https://arxiv.org/abs/2506.02285) with weight decay normalized by the
* `AdamC`, a ["corrected version of Adam"](https://arxiv.org/abs/2506.02285) with weight decay normalized by the
maximum LR
* `MSAMLaProp` built on top of [Momentum‑SAM](https://arxiv.org/abs/2401.12033)
* Chainable pipeline: