diff --git a/README.md b/README.md index e45eefe..5c15e1a 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ HeavyBall is an optimizer library for PyTorch where every optimizer is assembled from composable, compiled building blocks. It includes API-compatible replacements for `torch.optim.AdamW`, `SGD`, and `RMSprop`, alongside Muon, SOAP ( -Shampoo), PSGD (Kronecker), ADOPT, Schedule-Free, LaProp, and others. +Shampoo), PSGD (Kronecker), LATHER, ADOPT, Schedule-Free, LaProp, and others. The building blocks, over 100 functions in [`utils.py`](heavyball/utils.py), are each compiled with `torch.compile(fullgraph=True)` and fuse into Triton kernels. Features like MARS gradient correction, @@ -21,21 +21,31 @@ Requires PyTorch >= 2.2. ```python from heavyball import AdamW + opt = AdamW(model.parameters(), lr=1e-3) ``` ```python from heavyball import SOAP # Shampoo-based preconditioning + opt = SOAP(model.parameters(), lr=3e-3) ``` +```python +from heavyball import LATHER # Lie-group Adam Through Harmonic Eigenbasis Rotations + +opt = LATHER(model.parameters(), lr=1e-3) +``` + ```python from heavyball import Muon + opt = Muon(model.parameters(), lr=0.02, ecc="bf16+8", mars=True, caution=True) ``` ```python from heavyball import SplitOpt, Muon, AdamW + opt = SplitOpt([ {'params': matrices, 'optimizer': Muon, 'lr': 0.02}, {'params': vectors, 'optimizer': AdamW, 'lr': 1e-3}, @@ -44,6 +54,8 @@ opt = SplitOpt([ The API matches `torch.optim`, with the same parameter groups, same `step()`/`zero_grad()` interface. See [ `examples/`](examples/) for training scripts. +By default, HeavyBall consumes gradients during `step()` and clears `p.grad` once it has used it. Pass +`consume_grad=False` if your training loop needs gradients to remain attached after the optimizer step. ## Optimizers @@ -55,29 +67,25 @@ training, and SAM. 
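The gradient-lifetime contract described earlier (`consume_grad`) can be sketched with a toy optimizer. `ToyParam` and `ToySGD` below are hypothetical stand-ins for illustration only, not HeavyBall classes:

```python
# Toy sketch of the consume_grad contract: by default, the gradient is
# cleared once step() has used it; consume_grad=False leaves it attached.
class ToyParam:
    def __init__(self, value):
        self.value = value
        self.grad = None


class ToySGD:
    def __init__(self, params, lr=0.5, consume_grad=True):
        self.params = list(params)
        self.lr = lr
        self.consume_grad = consume_grad

    def step(self):
        for p in self.params:
            if p.grad is None:
                continue
            p.value -= self.lr * p.grad
            if self.consume_grad:
                p.grad = None  # gradient is consumed and cleared


p = ToyParam(1.0)
p.grad = 1.0
ToySGD([p]).step()
print(p.value, p.grad)  # 0.5 None

q = ToyParam(1.0)
q.grad = 1.0
ToySGD([q], consume_grad=False).step()
print(q.value, q.grad)  # 0.5 1.0
```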
Full list **First-order:** -AdamW, NAdam, RMSprop, ADOPT, ForeachAdEMAMix, LaProp, SignLaProp, SGD, Scion, UnscaledAdamW, ForeachAdamC, SUDSAdamW +AdamW, NAdam, RMSprop, ADOPT, AdEMAMix, LaProp, SignLaProp, SGD, Scion, UnscaledAdamW, AdamC, SUDSAdamW **Schedule-Free:** -SFAdamW, PaLMSFAdamW +SFAdamW Schedule-Free optimizers override `.eval()` and `.train()` to swap between training and evaluation parameter states. Call `opt.eval()` before validation and `opt.train()` before resuming training. **Orthogonal:** -Muon, MuonLaProp, OrthoLaProp, LaPropOrtho +Muon, MuonAdamW, MuonLaProp, HyperBallAdamW, OrthoLaProp, LaPropOrtho **Shampoo-based (SOAP):** -SOAP, PaLMSOAP, PrecondScheduleSOAP, PrecondSchedulePaLMSOAP, SOAPNAdam, SOAPAdEMAMix, ForeachSOLP +SOAP, SOAPNAdam, SOAPAdEMAMix, SOLP **PSGD (Kronecker):** -PSGDKron, CachedPSGDKron, DelayedPSGD, CachedDelayedPSGDKron, PurePSGD, NewtonPSGDKron, NewtonHybrid2PSGDKron - -`Newton`-PSGD requires a closure passed to `step()`. +PSGDKron, LATHER, PSGDPRO **PSGD (Low-Rank):** -PSGDLRA, DelayedPSGDLRA, NewtonPSGDLRA, NewtonHybrid2PSGDLRA - -`Newton`-PSGD requires a closure passed to `step()`. +PSGDLRA **SAM:** SAMWrapper, MSAMLaProp @@ -132,9 +140,10 @@ Available modes: `bf16+8`, `bf16+16`, `fp16+8`, `fp16+16`. HeavyBall works with both DDP and FSDP. First-order optimizers are elementwise and operate directly on FSDP shards with no repartitioning. Second-order methods (Muon, SOAP, PSGD) need the full parameter to compute their update, so HeavyBall -auto-detects FSDP-sharded parameters on the first step and repartitions them: each weight matrix is assigned to one rank -in round-robin, which reconstructs the full parameter, computes the update, and broadcasts the result. This saves both -compute and memory compared to DDP-style redundant updates, at the cost of communication. 
+auto-detects FSDP-sharded parameters on the first step and repartitions them with a metadata-first `all_to_all_single` +exchange: each weight matrix is deterministically assigned to one rank, shard metadata is exchanged up front, the owner +reconstructs the full parameter, computes the update once, and returns the updated shards. This saves both compute and +memory compared to DDP-style redundant updates, at the cost of communication. ```python from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -157,23 +166,25 @@ opt = SOAP(model.parameters(), lr=3e-3, orig_shapes=shapes) ## Building Custom Optimizers Every built-in optimizer is a chain of `FunctionTransform`s, an API also available for building custom optimizers. -`Branch` runs parallel transform paths with a merge function, which is useful for grafted optimizers or ensemble +`Parallel` runs parallel transform paths with a merge function, which is useful for grafted optimizers or ensemble updates. ```python import heavyball.chainable as C + def graft(outputs, eps=1e-8): adam_update, sgd_update = outputs return [s * (a.norm() / s.norm().add(eps)) for a, s in zip(adam_update, sgd_update)] + class GraftedAdam(C.BaseOpt): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, warmup_steps=0, foreach=True): + weight_decay=0, warmup_steps=0, multi_tensor=True): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, warmup_steps=warmup_steps) - branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft) - super().__init__(params, defaults, foreach, fns=(branch,)) + branch = C.Parallel(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft) + super().__init__(params, defaults, multi_tensor, fns=(branch,)) ``` Custom optimizers that inherit from `BaseOpt` get ECC, MARS, caution, clipping, warmup, and stochastic rounding @@ -204,14 +215,19 @@ Custom optimizers built via the chainable API inherit this behavior. 
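The `graft` merge function in the example above can be illustrated numerically. This is a toy single-tensor sketch using plain Python lists rather than tensors: the SGD update supplies the direction, the Adam update supplies the magnitude.

```python
import math


def norm(xs):
    return math.sqrt(sum(x * x for x in xs))


# Toy version of the graft merge: rescale the SGD update so its norm
# matches the Adam update's norm, while keeping the SGD direction.
def graft(adam_update, sgd_update, eps=1e-8):
    scale = norm(adam_update) / (norm(sgd_update) + eps)
    return [s * scale for s in sgd_update]


adam = [0.3, 0.4]  # norm 0.5
sgd = [3.0, 4.0]   # norm 5.0
print(graft(adam, sgd))  # ~[0.3, 0.4]: SGD direction at Adam's magnitude
```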
## Benchmarks

-HeavyBall includes a diagnostic benchmark suite via [LightBench](https://github.com/HomebrewML/LightBench) that tests
+HeavyBall includes a benchmark suite via [LightBench](https://github.com/HomebrewML/LightBench) that tests
for silent optimizer failures across difficulty levels. Results and methodology are documented
in [docs/benchmark.md](docs/benchmark.md).

-## Migrating from 1.x
+[`benchmarks/bench_optimizer_step.py`](benchmarks/bench_optimizer_step.py) measures optimizer latency, with
+AdamW step times dropping from 10.63 ms in HeavyBall 2 to 4.15 ms in HeavyBall 3.
+
+## Migrating
+
+**From 2.x** See the [3.0.0 migration guide](docs/heavyball3.md) for renamed classes, removed kwargs, and checkpoint
+conversion.

-See the [2.0.0 migration notes](docs/heavyball2.md) for a full checklist, and `scripts/migrate_optimizer_state.py` for
-checkpoint conversion.
+**From 1.x** See the [2.0.0 migration notes](docs/heavyball2.md), then follow the 3.0.0 guide.

## Contributing
diff --git a/assets/benchmark_matrix.png b/assets/benchmark_matrix.png
index 8300fca..26a7839 100644
Binary files a/assets/benchmark_matrix.png and b/assets/benchmark_matrix.png differ
diff --git a/benchmarks/bench_optimizer_step.py b/benchmarks/bench_optimizer_step.py
new file mode 100644
index 0000000..62ced5e
--- /dev/null
+++ b/benchmarks/bench_optimizer_step.py
@@ -0,0 +1,85 @@
+from enum import StrEnum
+from math import prod
+from time import perf_counter
+
+import numpy as np
+import torch
+import typer
+
+import heavyball
+
+app = typer.Typer(add_completion=False, pretty_exceptions_enable=False)
+
+DEFAULT_SHAPES = ((2048, 2048),) * 32
+
+
+class DType(StrEnum):
+    float16 = "float16"
+    bfloat16 = "bfloat16"
+    float32 = "float32"
+
+
+class Library(StrEnum):
+    heavyball = "heavyball"
+    torch = "torch"
+
+
+def parse_shape(text: str) -> tuple[int, ...]:
+    try:
+        shape = tuple(map(int, text.lower().replace("x", " ").split()))
+    except ValueError as e:
+        raise
typer.BadParameter(f"invalid shape: {text!r}") from e + if not shape: + raise typer.BadParameter(f"invalid shape: {text!r}") + return shape + + +@app.command() +def main( + optimizer: str = "AdamW", + library: Library = Library.heavyball, + dtype: DType = DType.float32, + shape: list[str] | None = None, + compile_step: bool = False, + fused: bool | None = None, + update_precond: bool | None = None, + steps: int = 300, + warmup: int = 20, + windows: int = 6, + seed: int = 0, +): + shapes = DEFAULT_SHAPES if shape is None else tuple(map(parse_shape, shape)) + torch_dtype = getattr(torch, dtype) + kwargs = {"compile_step": compile_step} if library is Library.heavyball else {} + if fused is not None and library is Library.torch: + kwargs["fused"] = fused + if update_precond is not None and library is Library.heavyball: + kwargs["preconditioner_update_probability"] = float(update_precond) + + gen = torch.Generator(device="cuda").manual_seed(seed) + params = [] + for dims in shapes: + param = torch.nn.Parameter(torch.randn(dims, device="cuda", dtype=torch_dtype, generator=gen)) + param.grad = torch.randn(dims, device="cuda", dtype=torch_dtype, generator=gen) + params.append(param) + + module = heavyball if library is Library.heavyball else torch.optim + step = getattr(module, optimizer)(params, **kwargs).step + for _ in range(warmup): + step() + + times = [] + for _ in range(windows): + torch.cuda.synchronize() + start = perf_counter() + for _ in range(steps): + step() + torch.cuda.synchronize() + times.append((perf_counter() - start) / steps) + + print(f"{len(shapes)} tensors, {sum(prod(s) for s in shapes)} total params") + print(f"Median Time: {np.median(times) * 1e6:.3f}µs") + + +if __name__ == "__main__": + app() diff --git a/benchmarks/bench_singular_values.py b/benchmarks/bench_singular_values.py index c324b64..f95f735 100644 --- a/benchmarks/bench_singular_values.py +++ b/benchmarks/bench_singular_values.py @@ -113,7 +113,7 @@ def key_fn(r): f"{key[0]:<8} 
{key[1]:<5} {key[2]:>3} {min(rerrs):>10.6f} {max(rerrs):>10.6f} {errs:>6} {len(items):>5}" ) else: - print(f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {'—':>10} {'—':>10} {errs:>6} {len(items):>5}") + print(f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {'-':>10} {'-':>10} {errs:>6} {len(items):>5}") def main(): diff --git a/docs/benchmark.md b/docs/benchmark.md index 398b051..f6ff709 100644 --- a/docs/benchmark.md +++ b/docs/benchmark.md @@ -56,9 +56,9 @@ reinforcing the need for diagnostic rather than purely comparative evaluation. | Optimizer | Cautious¹ | Mars² | Success | Attempts | Avg Runtime (s) | |:---------------|:----------|:------|:--------|:---------|:----------------| | PSGDKron | No | No | 77.0% | 73.2 | 8240 | -| NewtonPSGDKron | No | No | 77.0% | 80.5 | 9052 | +| PSGDKron (Newton) | No | No | 77.0% | 80.5 | 9052 | | AdamW | Yes | No | 75.7% | 61.2 | 8072 | -| ForeachSOAP | No | No | 72.5% | 77.9 | 7827 | +| SOAP | No | No | 72.5% | 77.9 | 7827 | | AdamW | No | No | 72.3% | 107.8 | 10029 | | MuonLaProp | No | No | 68.2% | 82.7 | 10141 | | RMSprop | No | No | 55.6% | 114.4 | 10725 | @@ -82,7 +82,7 @@ informed choice. ### Case Study: Escaping the Saddle Point An optimizer’s inability to navigate a saddle point is a classic example of a silent failure. A key test of an -optimizer's robustness is its ability to navigate a saddle point—a region that is a minimum in one direction but a +optimizer's robustness is its ability to navigate a saddle point - a region that is a minimum in one direction but a maximum in another. The gradient approaches zero at the center, trapping first-order methods that rely solely on the gradient. @@ -95,7 +95,7 @@ optimizer may be unreliable in these settings. ## Conclusion The HeavyBall Benchmark represents a necessary shift in how we evaluate optimizers, moving from a culture of -score-chasing to one of deep, diagnostic understanding. 
These hidden failures aren’t rare edge cases—they’re a routine +score-chasing to one of deep, diagnostic understanding. These hidden failures aren’t rare edge cases - they’re a routine source of wasted compute and disappointing models. By making them explicit, the benchmark equips researchers and practitioners with a detailed map of an optimizer's capabilities. By clearly identifying hidden failure modes, practitioners can confidently choose, tune, or reconsider their optimization strategies, ultimately leading to more diff --git a/docs/heavyball2.md b/docs/heavyball2.md index 5660cb1..7bc8555 100644 --- a/docs/heavyball2.md +++ b/docs/heavyball2.md @@ -4,7 +4,7 @@ * First‑class SAM via `SAMWrapper` (closure‑based) * More robust checkpoint/restore with HeavyBall‑internal state -* New optimizers: `SGD`, `ForeachAdamC`, `MSAMLaProp` +* New optimizers: `SGD`, `AdamC`, `MSAMLaProp` * Overhauled chainable pipeline: indexed transforms, branching, internal gradient‑accumulation, and `SqueezeGrad` * Faster, more accurate code paths * New `heavyball.helpers` with Optuna‑compatible samplers and utilities @@ -18,7 +18,7 @@ * `SAMWrapper` applies sharpness‑aware minimization to any HeavyBall optimizer while preserving the wrapped step logic; requires a closure * `SGD` built on the chainable internals -* `ForeachAdamC`, a ["corrected version of Adam"](https://arxiv.org/abs/2506.02285) with weight decay normalized by the +* `AdamC`, a ["corrected version of Adam"](https://arxiv.org/abs/2506.02285) with weight decay normalized by the maximum LR * `MSAMLaProp` built on top of [Momentum‑SAM](https://arxiv.org/abs/2401.12033) * Chainable pipeline: diff --git a/docs/heavyball3.md b/docs/heavyball3.md new file mode 100644 index 0000000..592f381 --- /dev/null +++ b/docs/heavyball3.md @@ -0,0 +1,146 @@ +# HeavyBall 3.0.0 + +## Highlights + +* Simplified public API: `Foreach*` prefixes removed, short names are now the canonical classes +* New optimizers: `HyperBallAdamW`, `MuonAdamW`, 
`LATHER`, `PSGDPRO`
+* `LATHER` expands to "Lie-group Adam Through Harmonic Eigenbasis Rotations"
+* `Route`-based param dispatch replaces manual `SplitOpt` for mixed-architecture optimizers
+* `ScheduleFree` and `MSAM` mode switches are now idempotent (`eval()` twice is safe)
+* Higher-precision PSGD preconditioner updates
+* New `consume_grad` option: `step()` clears `p.grad` after consuming it by default; set `consume_grad=False` to keep gradients attached after the step
+* `orig_shapes` is now an explicit documented optimizer argument; use `capture_param_shapes(...)` before wrapping models with sharding backends that do not preserve original parameter shapes
+* `torch.compile`-friendly step with automatic eager fallback for init/preconditioning
+
+---
+
+## Release benchmarks
+
+HeavyBall 3.0.0 was benchmarked against HeavyBall 2.0.0 and `torch.optim` with
+[`benchmarks/bench_optimizer_step.py`](../benchmarks/bench_optimizer_step.py), with compiled AdamW step latency
+dropping from 10.63 ms in HeavyBall 2.0.0 to 4.15 ms in HeavyBall 3.0.0, a 2.56x speedup.
+
+## Breaking changes
+
+### Class renames
+
+Every `Foreach*` class is renamed to its short form. The old short-form aliases (which existed
+in 2.x) keep working; only the `Foreach*` imports break.
+
+| 2.x name | 3.x name |
+|---|---|
+| `ForeachAdamW` | `AdamW` |
+| `ForeachNAdam` | `NAdam` |
+| `ForeachAdEMAMix` | `AdEMAMix` |
+| `ForeachAdamC` | `AdamC` |
+| `ForeachRMSprop` | `RMSprop` |
+| `ForeachSFAdamW` | `SFAdamW` |
+| `ForeachADOPT` | `ADOPT` |
+| `ForeachMuon` | `Muon` |
+| `ForeachLaProp` | `LaProp` |
+| `ForeachSignLaProp` | `SignLaProp` |
+| `ForeachSOAP` | `SOAP` |
+| `ForeachSOAPNAdam` | `SOAPNAdam` |
+| `ForeachSOAPAdEMAMix` | `SOAPAdEMAMix` |
+| `ForeachSOLP` | `SOLP` |
+| `ForeachPSGDKron` | `PSGDKron` |
+| `ForeachPSGDLRA` | `PSGDLRA` |
+
+### Removed optimizer classes
+
+These were thin subclasses that only set a class-level default.
Use the parent class with the +corresponding constructor argument instead. + +| 2.x class | 3.x equivalent | +|---|---| +| `PaLMForeachSFAdamW` / `PaLMSFAdamW` | `SFAdamW(..., palm=True)` | +| `PaLMForeachSOAP` / `PaLMSOAP` / `PalmForEachSoap` | `SOAP(..., palm=True)` | +| `PrecondScheduleForeachSOAP` / `PrecondScheduleSOAP` | `SOAP(..., use_precond_schedule=True)` | +| `PrecondSchedulePaLMForeachSOAP` / `PrecondSchedulePaLMSOAP` | `SOAP(..., palm=True, use_precond_schedule=True)` | +| `ForeachPurePSGD` / `PurePSGD` | `PSGDKron(..., exp_avg_input=False)` | +| `ForeachCachedPSGDKron` / `CachedPSGDKron` | `PSGDKron(...)` (caching is now the default) | +| `ForeachDelayedPSGD` / `DelayedPSGD` | `PSGDKron(..., delayed=True)` | +| `ForeachCachedDelayedPSGDKron` / `CachedDelayedPSGDKron` | `PSGDKron(..., delayed=True)` | +| `ForeachCachedNewtonPSGD` / `NewtonPSGDKron` | `PSGDKron(..., hessian_approx=True)` | +| `NewtonHybrid2PSGDKron` | `PSGDKron(..., hessian_approx=True, hvp_interval=2)` | +| `ForeachDelayedPSGDLRA` / `DelayedPSGDLRA` | `PSGDLRA(..., delayed=True)` | +| `ForeachNewtonPSGDLRA` / `NewtonPSGDLRA` | `PSGDLRA(..., hessian_approx=True)` | +| `NewtonHybrid2PSGDLRA` | `PSGDLRA(..., hessian_approx=True, hvp_interval=2)` | + +### Renamed parameters + +| 2.x parameter | 3.x parameter | Notes | +|---|---|---| +| `foreach` | `multi_tensor` | Passing `foreach` emits a `FutureWarning` and remaps automatically | + +### Removed parameters + +These raise `TypeError` if passed. They were either unused or replaced by better defaults. 
+ +| Parameter | Previously on | Notes | +|---|---|----------------------------------------------------------| +| `stochastic_schedule` | SOAP, PSGDKron, PSGDLRA | Deterministic accumulation schedule is now the only mode | +| `normalize_grads` | SOAP variants | Was unused in the transform pipeline | +| `correct_bias` | SOAP variants | Was unused in the transform pipeline | +| `inverse_free` | PSGDKron | Use `quad_torch` or PSGDPRO for inverse-free PSGD | +| `adaptive` | PSGDKron | Removed | + +### Helper sampler kwargs + +These compatibility kwargs were removed from `heavyball.helpers` samplers and now raise +`TypeError`. + +| Class | Removed kwargs | +|---|---| +| `BoTorchSampler` | `constraints_func`, `consider_running_trials` | +| `HEBOSampler` | `constant_liar` | +| `ImplicitNaturalGradientSampler` | `lr`, `warn_independent_sampling` | +| `AutoSampler` | `constraints_func` | + +### Chainable API renames + +| 2.x name | 3.x name | +|---|---| +| `Branch` | `Parallel` | + +### Behavioral changes + +* **ScheduleFree / MSAM `eval()` / `train()`**: Now idempotent. Calling `eval()` twice no + longer flips back to train mode. Both methods accept a `mode` argument matching + `nn.Module.train(mode)` and return `self`. +* **Gradient lifetime**: `consume_grad=True` is available on all optimizers and clears `p.grad` + during `step()` once the gradient has been consumed. Set `consume_grad=False` if your code + reads gradients after stepping or relies on them remaining attached. +* **Sharded parameter shapes**: Built-in optimizers now expose `orig_shapes` explicitly. Use + `capture_param_shapes()` before wrapping parameters if your sharding backend hides original + shapes. +* **PSGD dampening**: `dampen_grad` default changed from `2**-13` to `1e-9`, and dampening + epsilon uses `torch.finfo(float32).eps` regardless of input dtype. This improves + preconditioner accuracy but may change convergence behavior. 
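The idempotent mode switch can be sketched as follows. `ToyScheduleFree` is a hypothetical illustration of the contract (repeated `eval()` calls must not swap parameter states twice, and `train(mode)` returns `self`), not the HeavyBall implementation:

```python
# Toy sketch of the idempotent eval()/train() contract described above.
class ToyScheduleFree:
    def __init__(self):
        self.training = True
        self.swaps = 0  # counts parameter-state swaps

    def train(self, mode=True):
        if self.training != mode:
            self.swaps += 1  # placeholder for swapping train/eval parameter states
            self.training = mode
        return self

    def eval(self):
        return self.train(False)


opt = ToyScheduleFree()
opt.eval()
opt.eval()  # second call is a no-op: no double swap
print(opt.swaps, opt.training)  # 1 False
opt.train()
print(opt.swaps, opt.training)  # 2 True
```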
+ +--- + +## Checkpoint migration + +Use the migration CLI to convert 1.x or 2.x checkpoints: + +```bash +python scripts/migrate_optimizer_state.py +``` + +Old class names (including all aliases listed above) are resolved automatically. +The `foreach` → `multi_tensor` key rename in param groups is handled automatically. + +--- + +## Upgrade checklist + +1. Replace `from heavyball import Foreach*` with the short name (e.g., `ForeachAdamW` → `AdamW`) +2. Replace `foreach=` with `multi_tensor=` in constructor calls +3. Replace removed subclass instantiations with parent + kwargs (see table above) +4. Remove any `stochastic_schedule`, `normalize_grads`, `correct_bias`, `inverse_free`, or `adaptive` kwargs +5. Replace `Branch(...)` with `Parallel(...)` in custom chainable code +6. Migrate checkpoints: `python scripts/migrate_optimizer_state.py heavyball.` +7. If you relied on `eval(); eval()` toggling back to train mode, update your code +8. If your training loop reads `p.grad` after `step()`, pass `consume_grad=False` +9. 
Remove obsolete compatibility kwargs from `heavyball.helpers` samplers diff --git a/examples/autoencoder.py b/examples/autoencoder.py index 20350cb..615ddba 100644 --- a/examples/autoencoder.py +++ b/examples/autoencoder.py @@ -96,7 +96,6 @@ def main(epochs: int, batch: int, log_interval: int = 16): lr=1e-4, mars=True, lower_bound_beta=0.9, - inverse_free=True, precond_update_power_iterations=6, store_triu_as_line=False, ) diff --git a/examples/branched_optimizer.py b/examples/branched_optimizer.py index b8a2795..73aa271 100644 --- a/examples/branched_optimizer.py +++ b/examples/branched_optimizer.py @@ -29,7 +29,7 @@ def __init__( eps: float = 1e-8, weight_decay: float = 1e-4, warmup_steps: int = 0, - foreach: bool = True, + multi_tensor: bool = True, ): defaults = dict( lr=lr, @@ -38,8 +38,8 @@ def __init__( weight_decay=weight_decay, warmup_steps=warmup_steps, ) - branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=_graft) - super().__init__(params, defaults, foreach, fns=(branch,)) + branch = C.Parallel(branches=[[C.scale_by_adam], [C.identity]], merge_fn=_graft) + super().__init__(params, defaults, multi_tensor, fns=(branch,)) def main(epochs: int = 20, batch_size: int = 256, subset_size: int = 4096): diff --git a/examples/ddp_training.py b/examples/ddp_training.py index dc27092..bc39932 100644 --- a/examples/ddp_training.py +++ b/examples/ddp_training.py @@ -2,8 +2,8 @@ Launch with torchrun: torchrun --nproc_per_node=2 examples/ddp_training.py - torchrun --nproc_per_node=2 examples/ddp_training.py --opt ForeachSOAP - torchrun --nproc_per_node=2 examples/ddp_training.py --opt ForeachMuon --lr 0.01 + torchrun --nproc_per_node=2 examples/ddp_training.py --opt SOAP + torchrun --nproc_per_node=2 examples/ddp_training.py --opt Muon --lr 0.01 All HeavyBall optimizers work transparently with DDP """ @@ -38,7 +38,7 @@ def make_model(): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--opt", default="ForeachAdamW") + 
parser.add_argument("--opt", default="AdamW") parser.add_argument("--epochs", type=int, default=5) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--lr", type=float, default=1e-3) diff --git a/examples/ecc_bf16.py b/examples/ecc_bf16.py index 1b8d1bf..7effd5d 100644 --- a/examples/ecc_bf16.py +++ b/examples/ecc_bf16.py @@ -19,11 +19,11 @@ CONFIGS = { "naive_fp32": lambda p: NaiveAdamW(p, lr=LR, betas=BETAS, eps=EPS, state_dtype=torch.float32), "naive_bf16": lambda p: NaiveAdamW(p, lr=LR, betas=BETAS, eps=EPS, state_dtype=torch.bfloat16), - "heavyball_fp32": lambda p: heavyball.ForeachAdamW(p, lr=LR, betas=BETAS, eps=EPS, storage_dtype="float32"), - "heavyball_bf16": lambda p: heavyball.ForeachAdamW( + "heavyball_fp32": lambda p: heavyball.AdamW(p, lr=LR, betas=BETAS, eps=EPS, storage_dtype="float32"), + "heavyball_bf16": lambda p: heavyball.AdamW( p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, storage_dtype="bfloat16" ), - "ecc_bf16+8": lambda p: heavyball.ForeachAdamW(p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, ecc="bf16+8"), + "ecc_bf16+8": lambda p: heavyball.AdamW(p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, ecc="bf16+8"), } COLORS = { diff --git a/examples/fsdp_training.py b/examples/fsdp_training.py index 375d0e8..06087de 100644 --- a/examples/fsdp_training.py +++ b/examples/fsdp_training.py @@ -2,8 +2,8 @@ Launch with torchrun: torchrun --nproc_per_node=2 examples/fsdp_training.py - torchrun --nproc_per_node=2 examples/fsdp_training.py --opt ForeachSOAP - torchrun --nproc_per_node=2 examples/fsdp_training.py --opt ForeachMuon --lr 0.01 + torchrun --nproc_per_node=2 examples/fsdp_training.py --opt SOAP + torchrun --nproc_per_node=2 examples/fsdp_training.py --opt Muon --lr 0.01 Shape-aware optimizers (SOAP, Muon, PSGD, Scion, etc.) auto-detect FSDP-flattened params and restore original shapes. No manual intervention needed. 
@@ -11,7 +11,7 @@ For non-FSDP parallelism backends, capture shapes before wrapping: shapes = heavyball.capture_param_shapes(model) model = your_wrapper(model) - opt = heavyball.ForeachSOAP(model.parameters(), lr=3e-3, orig_shapes=shapes) + opt = heavyball.SOAP(model.parameters(), lr=3e-3, orig_shapes=shapes) """ import argparse @@ -44,7 +44,7 @@ def make_model(): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--opt", default="ForeachAdamW") + parser.add_argument("--opt", default="AdamW") parser.add_argument("--epochs", type=int, default=5) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--lr", type=float, default=1e-3) diff --git a/heavyball/__init__.py b/heavyball/__init__.py index e8c3eaf..db48ee3 100644 --- a/heavyball/__init__.py +++ b/heavyball/__init__.py @@ -7,6 +7,8 @@ from . import chainable as C from . import utils +ShapeMap = dict[int, tuple[int, ...]] + class SGD(C.BaseOpt): def __init__( @@ -16,7 +18,7 @@ def __init__( beta=0.9, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -27,13 +29,14 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,)) -class ForeachAdamW(C.BaseOpt): +class AdamW(C.BaseOpt): def __init__( self, params, @@ -42,7 +45,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -55,13 +58,14 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | 
None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,)) -class ForeachNAdam(C.BaseOpt): +class NAdam(C.BaseOpt): def __init__( self, params, @@ -72,7 +76,7 @@ def __init__( momentum_decay: float = 4e-3, decoupled_weight_decay: bool = False, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -85,13 +89,14 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,)) -class ForeachAdEMAMix(C.BaseOpt): +class AdEMAMix(C.BaseOpt): def __init__( self, params, @@ -103,7 +108,7 @@ def __init__( beta3_warmup: Optional[int] = None, alpha_warmup: Optional[int] = None, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -114,13 +119,14 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): if len(betas) != 3: raise ValueError("AdEMAMix expects betas with three coefficients.") params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,)) class UnscaledAdamW(C.BaseOpt): @@ -132,7 
+138,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -145,12 +151,11 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) - super().__init__( - params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,) - ) + super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)) class SUDSAdamW(C.BaseOpt): @@ -162,7 +167,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -176,10 +181,11 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,)) class Scion(C.BaseOpt): @@ -191,7 +197,7 @@ def __init__( eps: float = 1e-8, weight_decay: float = 0, warmup_steps: int = 0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -204,6 +210,7 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): if lr < 0: @@ -221,12 +228,10 @@ def __init__( defaults["scale"] = scale defaults.pop("momentum", None) - super().__init__( - params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm) - ) + super().__init__(params, 
defaults, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm)) -class ForeachAdamC(C.BaseOpt): +class AdamC(C.BaseOpt): def __init__( self, params, @@ -236,7 +241,7 @@ def __init__( weight_decay=0, max_lr: float | None = None, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -249,6 +254,7 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): if max_lr is None: @@ -258,10 +264,10 @@ def __init__( max_lr = lr params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,)) -class ForeachRMSprop(C.BaseOpt): +class RMSprop(C.BaseOpt): """ Debiased RMSprop (not torch.optim.RMSprop) """ @@ -276,7 +282,7 @@ def __init__( warmup_steps=0, r=0.0, weight_lr_power=2.0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -289,13 +295,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) super().__init__( params, defaults, - foreach, gradient_clipping, update_clipping, palm, @@ -303,7 +309,92 @@ def __init__( ) -class ForeachSFAdamW(C.ScheduleFree): +class HyperBallAdamW(C.BaseOpt): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + multi_tensor: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = 
C.use_default, + beta2_scale: float = 0.8, + compile_step: bool = C.use_default, + promote: bool = C.use_default, + ecc: str | None = None, + param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, + **kwargs, + ): + params, defaults = C._build_defaults(locals()) + super().__init__( + params, + defaults, + gradient_clipping, + update_clipping, + palm, + fns=( + C.scale_by_exp_avg_sq, + C.route( + (lambda p: p.ndim >= 2, C.update_by_hyperball), + default=C.apply_update, + ), + ), + ) + + +class MuonAdamW(C.BaseOpt): + def __init__( + self, + params, + lr=0.0025, + betas=(0.9, 0.99), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + multi_tensor: bool = True, + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + palm: bool = C.use_default, + beta2_scale: float = 0.8, + nesterov: bool = True, + compile_step: bool = C.use_default, + promote: bool = C.use_default, + ecc: str | None = None, + param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, + **kwargs, + ): + params, defaults = C._build_defaults(locals()) + ema = C.nesterov_ema if nesterov else C.exp_avg + super().__init__( + params, + defaults, + gradient_clipping, + update_clipping, + palm, + fns=( + C.route( + (lambda p: p.ndim >= 2, (ema, C.orthogonalize_update)), + default=C.scale_by_adam, + ), + ), + ) + + +class SFAdamW(C.ScheduleFree): def __init__( self, params, @@ -314,7 +405,7 @@ def __init__( warmup_steps=0, r=0.0, weight_lr_power=2.0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -327,13 +418,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) super().__init__( params, defaults, - foreach, 
gradient_clipping, update_clipping, palm, @@ -352,7 +443,7 @@ def __init__( warmup_steps=0, r=0.0, weight_lr_power=2.0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -366,13 +457,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) super().__init__( params, defaults, - foreach, gradient_clipping, update_clipping, palm, @@ -380,11 +471,7 @@ def __init__( ) -class PaLMForeachSFAdamW(ForeachSFAdamW): - palm: bool = True - - -class ForeachADOPT(C.BaseOpt): +class ADOPT(C.BaseOpt): def __init__( self, params, @@ -393,7 +480,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -406,13 +493,14 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,)) -class ForeachMuon(C.BaseOpt): +class Muon(C.BaseOpt): def __init__( self, params, @@ -421,7 +509,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -436,6 +524,7 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) @@ -450,7 +539,6 @@ def __init__( super().__init__( params, defaults, - foreach, gradient_clipping, 
update_clipping, palm, @@ -458,7 +546,7 @@ def __init__( ) -class ForeachLaProp(C.BaseOpt): +class LaProp(C.BaseOpt): def __init__( self, params, @@ -467,7 +555,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -480,10 +568,11 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) - super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,)) + super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,)) class MuonLaProp(C.BaseOpt): @@ -495,7 +584,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -508,13 +597,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) super().__init__( params, defaults, - foreach, gradient_clipping, update_clipping, palm, @@ -522,9 +611,31 @@ def __init__( ) -class ForeachSOAP(C.BaseOpt): +class SOAPBase(C.BaseOpt): + use_precond_schedule: bool = False + + def _build_soap_defaults(self, locals_dict, fns): + use_precond_schedule = C.default(locals_dict["use_precond_schedule"], self.use_precond_schedule) + params, defaults = C._build_defaults(locals_dict) + if use_precond_schedule: + del defaults["precondition_frequency"] + self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler")) + else: + del defaults["precond_scheduler"] + self.precond_schedule = 1 / defaults.pop("precondition_frequency") + super().__init__( + params, + defaults, + 
locals_dict["gradient_clipping"], + locals_dict["update_clipping"], + locals_dict.get("palm", False), + fns=fns, + ) + + +class SOAP(SOAPBase): """ - ForeachSOAP + SOAP Sources: Baseline SOAP: @@ -534,8 +645,6 @@ class ForeachSOAP(C.BaseOpt): https://github.com/nikhilvyas/SOAP """ - use_precond_schedule: bool = False - def __init__( self, params, @@ -548,11 +657,9 @@ def __init__( max_precond_dim: int = 2048, # merge_dims: bool = True, precondition_1d: bool = False, - normalize_grads: bool = False, - correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, - foreach: bool = True, + multi_tensor: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, @@ -563,38 +670,18 @@ def __init__( gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, storage_dtype: str = "float32", - stochastic_schedule: bool = False, precond_grad_accum: bool = False, compile_step: bool = C.use_default, promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): - use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule) - - params, defaults = C._build_defaults(locals()) - - if use_precond_schedule: - del defaults["precondition_frequency"] - self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler")) - else: - del defaults["precond_scheduler"] - self.precond_schedule = 1 / defaults.pop("precondition_frequency") - super().__init__( - params, - defaults, - foreach, - gradient_clipping, - update_clipping, - palm, # - fns=(C.scale_by_soap,), - ) + self._build_soap_defaults(locals(), fns=(C.scale_by_soap,)) -class ForeachSOAPNAdam(C.BaseOpt): - use_precond_schedule: bool = False - +class SOAPNAdam(SOAPBase): def __init__( self, params, @@ -607,11 +694,9 @@ def __init__( max_precond_dim: int = 2048, merge_dims: bool = True, precondition_1d: bool = False, - normalize_grads: bool = False, - 
correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, - foreach: bool = True, + multi_tensor: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, @@ -622,7 +707,6 @@ def __init__( gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, storage_dtype: str = "float32", - stochastic_schedule: bool = False, precond_grad_accum: bool = False, momentum_decay: float = 4e-3, decoupled_weight_decay: bool = False, @@ -630,32 +714,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): - use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule) + self._build_soap_defaults(locals(), fns=(C.scale_by_soap_nadam,)) - params, defaults = C._build_defaults(locals()) - - if use_precond_schedule: - del defaults["precondition_frequency"] - self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler")) - else: - del defaults["precond_scheduler"] - self.precond_schedule = 1 / defaults.pop("precondition_frequency") - super().__init__( - params, - defaults, - foreach, - gradient_clipping, - update_clipping, - palm, - fns=(C.scale_by_soap_nadam,), - ) - - -class ForeachSOAPAdEMAMix(C.BaseOpt): - use_precond_schedule: bool = False +class SOAPAdEMAMix(SOAPBase): def __init__( self, params, @@ -668,11 +733,9 @@ def __init__( max_precond_dim: int = 2048, merge_dims: bool = True, precondition_1d: bool = False, - normalize_grads: bool = False, - correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, - foreach: bool = True, + multi_tensor: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, @@ -683,7 +746,6 @@ def __init__( gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, storage_dtype: str = "float32", - stochastic_schedule: bool = False, precond_grad_accum: bool = 
False, alpha: float = 2.0, beta3_warmup: int | None = None, @@ -692,30 +754,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): - use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule) + self._build_soap_defaults(locals(), fns=(C.scale_by_soap_ademamix,)) - params, defaults = C._build_defaults(locals()) - if use_precond_schedule: - del defaults["precondition_frequency"] - self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler")) - else: - del defaults["precond_scheduler"] - self.precond_schedule = 1 / defaults.pop("precondition_frequency") - super().__init__( - params, - defaults, - foreach, - gradient_clipping, - update_clipping, - palm, - fns=(C.scale_by_soap_ademamix,), - ) - - -class ForeachSignLaProp(C.BaseOpt): +class SignLaProp(C.BaseOpt): def __init__( self, params, @@ -724,7 +769,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -737,13 +782,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) super().__init__( params, defaults, - foreach, gradient_clipping, update_clipping, palm, @@ -751,9 +796,9 @@ def __init__( ) -class ForeachSOLP(C.BaseOpt): +class SOLP(SOAPBase): """ - ForeachSOLP + SOLP Sources: Baseline SOAP: @@ -763,8 +808,6 @@ class ForeachSOLP(C.BaseOpt): https://github.com/nikhilvyas/SOAP """ - use_precond_schedule: bool = False - def __init__( self, params, @@ -777,11 +820,9 @@ def __init__( max_precond_dim: int = 2048, # merge_dims: bool = True, precondition_1d: bool = False, - normalize_grads: bool = False, - correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, - foreach: 
bool = True, + multi_tensor: bool = True, mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, @@ -792,46 +833,14 @@ def __init__( gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, storage_dtype: str = "float32", - stochastic_schedule: bool = False, compile_step: bool = C.use_default, promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): - use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule) - - params, defaults = C._build_defaults(locals()) - - if use_precond_schedule: - del defaults["precondition_frequency"] - self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler")) - else: - del defaults["precond_scheduler"] - self.precond_schedule = 1 / defaults.pop("precondition_frequency") - super().__init__( - params, - defaults, - foreach, - gradient_clipping, - update_clipping, - palm, # - fns=(C.scale_by_soap_laprop,), - ) - - -class PaLMForeachSOAP(ForeachSOAP): - use_precond_schedule: bool = False - palm: bool = True - - -class PrecondScheduleForeachSOAP(ForeachSOAP): - use_precond_schedule: bool = True - - -class PrecondSchedulePaLMForeachSOAP(ForeachSOAP): - use_precond_schedule: bool = True - palm: bool = True + self._build_soap_defaults(locals(), fns=(C.scale_by_soap_laprop,)) class OrthoLaProp(C.BaseOpt): @@ -843,7 +852,7 @@ def __init__( eps=1e-8, weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -856,13 +865,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) super().__init__( params, defaults, - foreach, gradient_clipping, update_clipping, palm, @@ -879,7 +888,7 @@ def __init__( eps=1e-8, 
weight_decay=0, warmup_steps=0, - foreach: bool = True, + multi_tensor: bool = True, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -892,13 +901,13 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): params, defaults = C._build_defaults(locals()) super().__init__( params, defaults, - foreach, gradient_clipping, update_clipping, palm, @@ -906,18 +915,43 @@ def __init__( ) -class ForeachPSGDKron(C.BaseOpt): +class PSGDBase(C.BaseOpt): + delayed: bool = False + cached: bool = False + exp_avg_input: bool = True + + def _build_psgd_defaults( + self, locals_dict, fns, *, default_update_clipping=utils.trust_region_clip_, extra_defaults=None + ): + exp_avg_input = C.default(locals_dict.get("exp_avg_input", C.use_default), self.exp_avg_input) + update_clipping = C.default(locals_dict["update_clipping"], default_update_clipping) + locals_dict = {**locals_dict, "exp_avg_input": exp_avg_input, "update_clipping": update_clipping} + params, defaults = C._build_defaults(locals_dict) + + self.precond_schedule = C.default( + defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule() + ) + + if extra_defaults: + defaults.update(extra_defaults) + + super().__init__( + params, + defaults, + locals_dict["gradient_clipping"], + update_clipping, + False, + fns=(*(C.exp_avg,) * exp_avg_input, *fns), + ) + + +class PSGDKron(PSGDBase): """ Originally from Evan Walters and Omead Pooladzandi, 2024 Modified under Creative Commons Attribution 4.0 International Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py """ - delayed: bool = False - cached: bool = False - exp_avg_input: bool = True - quad: bool = False - def __init__( self, params, @@ -934,9 +968,8 @@ def __init__( merge_dims: bool = False, split: bool = False, store_triu_as_line: bool = True, - 
foreach: bool = True, + multi_tensor: bool = True, q_dtype="float32", - stochastic_schedule: bool = False, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -946,12 +979,9 @@ def __init__( exp_avg_input: Optional[bool] = C.use_default, gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, # - adaptive: bool = False, - ortho_method: Optional[str] = None, # If None, no orthogonalization precond_grad_accum: bool = False, - lower_bound_beta: float = 0.9, # 0.0 recovers pre-2.0.0 PSGD - inverse_free: bool = C.use_default, - dampening: float = 2**-13, + lower_bound_beta: float = 0.9, + dampening: float = 1e-9, precond_update_power_iterations: int = 2, # expert parameters precond_init_scale=None, @@ -966,77 +996,140 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): delayed = C.default(delayed, self.delayed) cached = C.default(cached, self.cached) - exp_avg_input = C.default(exp_avg_input, self.exp_avg_input) - update_clipping = C.default(update_clipping, utils.trust_region_clip_) - inverse_free = C.default(inverse_free, self.quad) - if inverse_free: - raise ValueError( - "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. 
Consider using https://github.com/evanatyourservice/quad_torch" - ) - - params, defaults = C._build_defaults(locals()) - - self.precond_schedule = C.default( - defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule() + self._build_psgd_defaults( + locals(), + fns=(functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),), ) - super().__init__( - params, - defaults, - foreach, - gradient_clipping, - update_clipping, - False, # - fns=( - *(C.exp_avg,) * exp_avg_input, - functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached), - ), - ) - - -class ForeachPurePSGD(ForeachPSGDKron): - exp_avg_input: bool = False - - -class ForeachCachedDelayedPSGDKron(ForeachPSGDKron): - delayed: bool = True - cached: bool = True - - -class ForeachCachedPSGDKron(ForeachPSGDKron): - cached: bool = True +class LATHER(PSGDBase): + """ + Lie-group Adam Through Harmonic Eigenbasis Rotations. + Runs Adam in the approximate eigenspace induced by the PSGD-Kron preconditioner, then maps back. 
+ """ -class ForeachDelayedPSGD(ForeachPSGDKron): - delayed: bool = True - + def __init__( + self, + params, + lr=0.001, + beta=None, + betas=(0.9, 0.999), + eps: float = 1e-8, + weight_decay=0.0, + preconditioner_update_probability=C.use_default, + max_size_triangular=2048, + min_ndim_triangular=2, + memory_save_mode=None, + momentum_into_precond_update=True, + warmup_steps: int = 0, + merge_dims: bool = False, + split: bool = False, + store_triu_as_line: bool = True, + multi_tensor: bool = True, + q_dtype="float32", + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + precond_grad_accum: bool = False, + lower_bound_beta: float = 0.9, + dampening: float = 1e-9, + precond_update_power_iterations: int = 2, + precond_init_scale=None, + precond_init_scale_scale: float = 1, + precond_init_scale_power: Optional[float] = None, + precond_lr: float = 0.1, + finite_differences: bool = C.use_default, + fallback_to_finite_differences: bool = C.use_default, + hvp_interval: int = C.use_default, + hessian_approx: bool = C.use_default, + compile_step: bool = C.use_default, + promote: bool = C.use_default, + ecc: str | None = None, + param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, + **kwargs, + ): + self._build_psgd_defaults( + {**locals(), "exp_avg_input": False}, + fns=(C.scale_by_lather,), + ) -class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron): - hessian_approx = True +class PSGDPRO(PSGDBase): + """ + PSGD with Q0.5EQ1.5 (PRO/Procrustes) preconditioner update. + Solve-free alternative to standard PSGD-Kron (EQ method). 
+ Reference: https://github.com/lixilinx/psgd_torch + """ -class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD): - hvp_interval = 2 + def __init__( + self, + params, + lr=0.001, + beta=None, + betas=(0.9, 0.999), + weight_decay=0.0, + preconditioner_update_probability=C.use_default, + max_size_triangular=2048, + min_ndim_triangular=2, + memory_save_mode=None, + momentum_into_precond_update=True, + warmup_steps: int = 0, + merge_dims: bool = False, + split: bool = False, + multi_tensor: bool = True, + q_dtype="float32", + storage_dtype: str = "float32", + mars: bool = False, + caution: bool = False, + mars_gamma: float = 0.0025, + cached: Optional[bool] = C.use_default, + exp_avg_input: Optional[bool] = C.use_default, + gradient_clipping: C.str_or_fn = C.use_default, + update_clipping: C.str_or_fn = C.use_default, + precond_grad_accum: bool = False, + lower_bound_beta: float = 0.9, + dampening: float = 1e-9, + precond_update_power_iterations: int = 2, + precond_init_scale=None, + precond_init_scale_scale: float = 1, + precond_init_scale_power: Optional[float] = None, + precond_lr: float = 0.1, + compile_step: bool = C.use_default, + promote: bool = C.use_default, + ecc: str | None = None, + param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, + **kwargs, + ): + cached = C.default(cached, self.cached) + self._build_psgd_defaults( + locals(), + fns=(functools.partial(C.scale_by_psgd_pro, cached=cached),), + default_update_clipping=None, + extra_defaults={"store_triu_as_line": False}, + ) -class ForeachPSGDLRA(C.BaseOpt): +class PSGDLRA(PSGDBase): """ Originally from Evan Walters and Omead Pooladzandi, 2024 Modified under Creative Commons Attribution 4.0 International Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py - Note: `foreach=True` (default) uses a single global low-rank approximation shared across all - parameters, while `foreach=False` fits an independent 
per-parameter LRA. These are different + Note: `multi_tensor=True` (default) uses a single global low-rank approximation shared across all + parameters, while `multi_tensor=False` fits an independent per-parameter LRA. These are different algorithms and will produce different results. """ - delayed: bool = False - exp_avg_input: bool = True - def __init__( self, params, @@ -1047,9 +1140,8 @@ def __init__( momentum_into_precond_update=True, rank: Optional[int] = None, warmup_steps: int = 0, - foreach: bool = True, # True: global LRA across all params. False: independent per-param LRA. + multi_tensor: bool = True, # True: global LRA across all params. False: independent per-param LRA. q_dtype="float32", - stochastic_schedule: bool = False, storage_dtype: str = "float32", mars: bool = False, caution: bool = False, @@ -1058,8 +1150,8 @@ def __init__( exp_avg_input: Optional[bool] = C.use_default, gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, - eps: float = 1e-8, # - precond_grad_accum: bool = False, # expert parameters + eps: float = 1e-8, + precond_grad_accum: bool = False, precond_init_scale=None, precond_init_scale_scale: float = 1, precond_init_scale_power: Optional[float] = None, @@ -1072,49 +1164,24 @@ def __init__( promote: bool = C.use_default, ecc: str | None = None, param_ecc: str | None = None, + orig_shapes: ShapeMap | None = None, **kwargs, ): delayed = C.default(delayed, self.delayed) - exp_avg_input = C.default(exp_avg_input, self.exp_avg_input) - update_clipping = C.default(update_clipping, utils.trust_region_clip_) - - params, defaults = C._build_defaults(locals()) - - self.precond_schedule = C.default( - defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule() - ) - if rank is None: utils.warn_once( f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. 
Currently, {type(params)=}" ) params = list(params) - defaults["rank"] = max(1, round(math.log2(sum(p.numel() for p in params)))) - utils.warn_once(f"rank was set to {defaults['rank']}") + rank = max(1, round(math.log2(sum(p.numel() for p in params)))) + utils.warn_once(f"rank was set to {rank}") - super().__init__( - params, - defaults, - foreach, - gradient_clipping, - update_clipping, - False, # - fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra), + self._build_psgd_defaults( + locals(), + fns=(C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,), ) -class ForeachDelayedPSGDLRA(ForeachPSGDLRA): - delayed: bool = True - - -class ForeachNewtonPSGDLRA(ForeachPSGDLRA): - hessian_approx = True - - -class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA): - hvp_interval = 2 - - class SplitOpt(utils.StatefulOptimizer): """ Delegates different parameter groups to different underlying optimizers. @@ -1135,7 +1202,7 @@ def __init__(self, specs): all_params.extend(params) if not self.optimizers: raise ValueError("No optimizers created") - super().__init__(all_params, {}, foreach=True) + super().__init__(all_params, {"multi_tensor": True}) def _step(self, group): pass @@ -1166,7 +1233,7 @@ class SAMWrapper(torch.optim.Optimizer): def __init__( self, params, - wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW, + wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = AdamW, ball: float = 0.1, ): params = list(params) @@ -1210,31 +1277,10 @@ def zero_grad(self, set_to_none: bool = True): self.wrapped_optimizer.zero_grad(set_to_none=set_to_none) -PalmForEachSoap = PaLMForeachSOAP -PaLMSOAP = PaLMForeachSOAP -PaLMSFAdamW = PaLMForeachSFAdamW -SOAP = ForeachSOAP -SOAPAdEMAMix = ForeachSOAPAdEMAMix -SOAPNAdam = ForeachSOAPNAdam -SFAdamW = ForeachSFAdamW -LaProp = ForeachLaProp -ADOPT = ForeachADOPT -RMSprop = ForeachRMSprop -PrecondScheduleSOAP = 
PrecondScheduleForeachSOAP -PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP -PSGDKron = ForeachPSGDKron -AdamW = ForeachAdamW -NAdam = ForeachNAdam -PurePSGD = ForeachPurePSGD -DelayedPSGD = ForeachDelayedPSGD -CachedPSGDKron = ForeachCachedPSGDKron -CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron -Muon = ForeachMuon -SignLaProp = ForeachSignLaProp -DelayedPSGDLRA = ForeachDelayedPSGDLRA -PSGDLRA = ForeachPSGDLRA -NewtonPSGDLRA = ForeachNewtonPSGDLRA -NewtonPSGDKron = ForeachCachedNewtonPSGD - capture_param_shapes = utils.capture_param_shapes -__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)] +_BASE_CLASSES = {SOAPBase, PSGDBase} +__all__ = [ + k + for k, v in globals().items() + if isinstance(v, type) and issubclass(v, torch.optim.Optimizer) and v not in _BASE_CLASSES +] diff --git a/heavyball/chainable.py b/heavyball/chainable.py index eca7ea7..100c7c2 100644 --- a/heavyball/chainable.py +++ b/heavyball/chainable.py @@ -2,7 +2,7 @@ import copy import functools import math -import random +import warnings from collections.abc import Iterable as _Iterable from typing import Iterable, List, Literal, Optional, Union @@ -25,13 +25,6 @@ def _key_in_state(state, key): return True -def _inplace_guard_(state, key, template_fn): - key_not_in_state = not _key_in_state(state, key) - if key_not_in_state: - template_fn() - return key_not_in_state - - def _guard_in_state(state, key, template_fn): if not _key_in_state(state, key): state[key] = template_fn() @@ -45,7 +38,6 @@ def __init__(self, fn, names: list[str] | None = None): self.fn = fn self.fn_name = self.get_fn().__name__ self.transform_idx = None - self.is_initialized = False self.names = names def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs): @@ -55,7 +47,7 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs): raise NotImplementedError def __call__(self, state, group, 
update, grad, param, *args, **kwargs): - states = [state(p) for p in param] + states = state if isinstance(state, list) else [state(p) for p in param] skip_update = False for st, a in zip(states, zip(update, grad, param, *args)): if self.transform_idx not in st.get("is_initialized", set()): @@ -77,28 +69,106 @@ def get_fn(self): return self.fn.get_fn() return self.fn + def _build_val_names(self): + self._val_names = {name: f"{self.fn_name}_{name}_{self.transform_idx}" for name in self.names} + def val_name(self, name): - assert self.transform_idx is not None - return f"{self.fn_name}_{name}_{self.transform_idx}" + return self._val_names[name] def __repr__(self): return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})" -class Branch: +def _enforce_uniform_skip(results): + skips = [skip for _, skip, _ in results] + if not skips: + return False + if all(skips): + return True + if not any(skips): + return False + raise ValueError("All branches must uniformly skip or not skip updates") + + +def _normalize_chain(fns): + if fns is None: + return None + return fns if isinstance(fns, (list, tuple)) else (fns,) + + +class Parallel: def __init__(self, branches: List[List[callable]], merge_fn: callable): self.branches = branches self.merge_fn = merge_fn def __call__(self, state, group, update, grad, param): - outputs = [] + results = [] for branch in self.branches: branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update] - branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch) - if skip_update: - raise ValueError("Branches should not skip updates") - outputs.append(branch_update) - return self.merge_fn(outputs) + u, skip = _inner_chain(state, group, branch_update, grad, param, *branch) + results.append((u, skip, None)) + if _enforce_uniform_skip(results): + raise SkipUpdate from None + return self.merge_fn([u for u, _, _ in results]) + + +class Route: + """Route params by predicate 
through different fn chains. + + Takes arbitrary (predicate, fns) pairs. Each param is assigned to the first + matching route; unmatched params use the default chain (None = passthrough). + All routes must uniformly either skip or not skip updates. + """ + + def __init__(self, *routes, default=None): + self.routes = [(pred, _normalize_chain(fns)) for pred, fns in routes] + self.default = _normalize_chain(default) + + def __call__(self, state, group, update, grad, param): + buckets = {} + assigned = set() + for j, (pred, _) in enumerate(self.routes): + for i, p in enumerate(param): + if i not in assigned and pred(p): + buckets.setdefault(j, []).append(i) + assigned.add(i) + default_idx = [i for i in range(len(param)) if i not in assigned] + + def _sel(lst, idx): + return [lst[i] for i in idx] + + caution = group["caution"] + results = [] + + all_chains = [(buckets.get(j), fns) for j, (_, fns) in enumerate(self.routes)] + if default_idx: + all_chains.append((default_idx, self.default)) + + for idx, fns in all_chains: + if not idx: + continue + group["caution"] = caution + if fns is not None: + u, skip = _inner_chain( + _sel(state, idx), group, _sel(update, idx), _sel(grad, idx), _sel(param, idx), *fns + ) + else: + u, skip = _sel(update, idx), False + results.append((u, skip, idx)) + + if _enforce_uniform_skip(results): + raise SkipUpdate from None + + out = [None] * len(param) + for u_list, _, idx in results: + if u_list is not None: + for i, u in zip(idx, u_list): + out[i] = u + return out + + +def route(*routes, default=None): + return Route(*routes, default=default) def _zero_guard(state, key, ref, dtype): @@ -112,12 +182,38 @@ def _storage_dtype(group): _PASSTHROUGH_KWARGS = {"orig_shapes"} +_RENAMED_KWARGS = {"foreach": "multi_tensor"} + +_REMOVED_KWARGS = frozenset( + { + "stochastic_schedule", + "normalize_grads", + "correct_bias", + "inverse_free", + "adaptive", + } +) + def _build_defaults(locals_dict): d = locals_dict.copy() - d.pop("self") + d.pop("self", 
 None)
     params = d.pop("params")
     kwargs = d.pop("kwargs")
+
+    for old, new in _RENAMED_KWARGS.items():
+        if old in kwargs:
+            warnings.warn(
+                f"'{old}' was renamed to '{new}' in HeavyBall 3.0. Pass '{new}' instead.", FutureWarning, stacklevel=4
+            )
+            d[new] = kwargs.pop(old)
+
+    hit = _REMOVED_KWARGS & kwargs.keys()
+    if hit:
+        raise TypeError(
+            f"Removed in HeavyBall 3.0: {', '.join(sorted(hit))}. See docs/heavyball3.md for migration details."
+        )
+
     d.update(kwargs)
     unknown = {k: v for k, v in kwargs.items() if k not in _PASSTHROUGH_KWARGS}
     if unknown:
@@ -197,12 +293,11 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         if ecc is None:
             return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
-        states = [state(p) for p in param]
         names = [self.val_name(n) for n in self.names]
-        primary_vars = [[st[vn] for st in states] for vn in names]
+        primary_vars = [[st[vn] for st in state] for vn in names]
         with contextlib.ExitStack() as stack:
             for vn, plist in zip(names, primary_vars):
-                corrs = [st[vn + "::ecc"] for st in states]
+                corrs = [st[vn + "::ecc"] for st in state]
                 stack.enter_context(ecc.attached(plist, corrs))
             return self.fn(state, group, update, grad, param, *args, *primary_vars, **kwargs)
@@ -210,45 +305,50 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
 class PrecondGradAccumGuard(FunctionTransform):
     def __init__(self, fn):
         super().__init__(fn, ["precond_grad_accum"])
-        self.steps_taken = 0
-        self.pass_through = None
+        self.steps_taken_key = None
+
+    def _build_val_names(self):
+        super()._build_val_names()
+        self.steps_taken_key = f"_{self.fn_name}_steps_taken_{self.transform_idx}"
 
-    def _accum(self, state, new):
-        self.steps_taken += 1
+    def _accum(self, group, state, new):
+        group[self.steps_taken_key] = group.get(self.steps_taken_key, 0) + 1
         utils.stochastic_add_(state, new)
 
-    def _reset(self, state):
-        if self.steps_taken != 0:
-            self.steps_taken = 0
+    def _reset(self, group, state):
+        if group.get(self.steps_taken_key, 0) != 0:
+            group[self.steps_taken_key] = 0
             utils.zero_(state)
 
     def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
-        if self.pass_through is None:
-            self.pass_through = not group.get("precond_grad_accum", False)
-        if self.pass_through is False:
-            for name in self.names:
-                _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+        if not group.get("precond_grad_accum", False):
+            return
+        for name in self.names:
+            _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
 
     def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         base_grad = update if group.get("momentum_into_precond_update", True) else grad
-        if self.pass_through:
+        if not group.get("precond_grad_accum", False):
             return self.fn(state, group, update, grad, param, *args, base_grad, **kwargs)
         (vars,) = vars
+        steps_taken = group.get(self.steps_taken_key, 0)
+        accum_state = None
         if group["is_preconditioning"]:
-            if self.steps_taken:
-                self._accum(vars, base_grad)
-                utils.stochastic_multiply_(vars, 1 / self.steps_taken)
+            if steps_taken:
+                self._accum(group, vars, base_grad)
+                utils.stochastic_multiply_(vars, 1 / group[self.steps_taken_key])
+                accum_state = vars
             else:
                 vars = base_grad
         else:
-            self._accum(vars, base_grad)
+            self._accum(group, vars, base_grad)
             vars = base_grad
         try:
             out = self.fn(state, group, update, grad, param, *args, vars, **kwargs)
         finally:
-            if group["is_preconditioning"]:
-                self._reset(vars)
+            if accum_state is not None:
+                self._reset(group, accum_state)
         return out
@@ -272,17 +372,6 @@ def __init__(self, fn, names, init_fn, skip_first: bool = True):
         super().__init__(fn, names)
         self.init_fn = init_fn
         self.skip_first = skip_first
-        self.named_to_anonymous = None
-        self.anonymous_to_named = None
-
-    def _map(self, state_fn, param, mapping):
-        for p in param:
-            state = state_fn(p)
-            for name, mapped in mapping.items():
-                if mapped in state:
-                    raise ValueError(f"Name {name} already mapped to {mapped}")
-                if name in state:
-                    state[mapped] = state.pop(name)
 
     def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
         self.init_fn(state, group, update, grad, param, **kwargs)
@@ -296,12 +385,19 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
 
 class NoState(FunctionTransform):
+    needs_init = False
+
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
         return self.fn(group, update, grad, param, *args, **kwargs)
 
-class NoStateNoForeach(FunctionTransform):
+class NoStateNoMultiTensor(FunctionTransform):
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        states = state if isinstance(state, list) else [state(p) for p in param]
+        for st in states:
+            if "is_initialized" not in st:
+                st["is_initialized"] = set()
+            st["is_initialized"].add(self.transform_idx)
         updates = []
         skip_update = False
         for a in zip(update, grad, param, *args):
@@ -324,6 +420,8 @@ def _view_preserve_ecc(src, target):
 
 class SqueezeGrad(FunctionTransform):
+    needs_init = False
+
     def __call__(self, state, group, update, grad, param, *args, **kwargs):
         original_shapes = [u.shape for u in update]
         update = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in update]
@@ -353,6 +451,32 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
         return self.fn(state, group, update, grad, param, *args, **kwargs)
 
 
+class WarmupGuard(FunctionTransform):
+    def __init__(self, fn, warmup_fns):
+        super().__init__(fn, names=[])
+        self.warmup_fns = warmup_fns
+        self.warmup_key = None
+
+    def _build_val_names(self):
+        super()._build_val_names()
+        self.warmup_key = f"_warmup_{self.transform_idx}"
+
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        states = state if isinstance(state, list) else [state(p) for p in param]
+        warmup_step = min(st.get(self.warmup_key, 0) for st in states)
+        if warmup_step < len(self.warmup_fns):
+            fn = self.warmup_fns[warmup_step]
+            for st, a in zip(states, zip(update, grad, param, *args)):
+                fn(st, group, *a, **kwargs)
+                st[self.warmup_key] = st.get(self.warmup_key, 0) + 1
+            raise SkipUpdate from None
+        for st in states:
+            if "is_initialized" not in st:
+                st["is_initialized"] = set()
+            st["is_initialized"].add(self.transform_idx)
+        return self.fn(state, group, update, grad, param, *args, **kwargs)
+
+
 needs_full_param = functools.partial(TagGuard, needs_full_param=True)
@@ -368,12 +492,16 @@ def general_guard(*names, init_fn, skip_first: bool = True):
     return functools.partial(GeneralGuard, names=names, init_fn=init_fn, skip_first=skip_first)
 
 
+def warmup_guard(*warmup_fns):
+    return functools.partial(WarmupGuard, warmup_fns=list(warmup_fns))
+
+
 def no_state(fn):
     return NoState(fn)
 
-def no_state_no_foreach(fn):
-    return NoStateNoForeach(fn)
+def no_state_no_multi_tensor(fn):
+    return NoStateNoMultiTensor(fn)
 
 class SkipUpdate(ValueError):
@@ -405,6 +533,12 @@ def identity(state, group, update, grad, param):
     return update
 
 
+@no_state
+def apply_update(group, update, grad, param):
+    utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
+    raise SkipUpdate from None
+
+
 @zero_guard("exp_avg")
 @no_state
 def weight_decay_to_ema(group, update, grad, param, exp_avg):
@@ -619,14 +753,14 @@ def orthogonalize_grad_to_param(group, update, grad, param):
 @copy_guard(2, "z")
 @no_state
 def update_by_schedule_free(group, update, grad, param, z):
-    # Compute weight_sum once per step, not per param in no-foreach mode.
-    if group.get("_sf_step") != group["step"]:
-        weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
+    # Compute weight_sum once per step, not per param in no-multi_tensor mode.
+    if group.get("_sf_step") is not group["step"]:
+        weight = abs(group["lr"]) ** group["weight_lr_power"] * group["step"].clamp(min=1) ** group["r"]
         group["weight_sum"] = group.get("weight_sum", 0) + weight
         group["_sf_step"] = group["step"]
     weight_sum = group["weight_sum"]
-    weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
+    weight = abs(group["lr"]) ** group["weight_lr_power"] * group["step"].clamp(min=1) ** group["r"]
     try:
         ckp1 = weight / weight_sum
     except ZeroDivisionError:
@@ -658,28 +792,23 @@ def update_by_msam(group, update, grad, param, z, exp_avg):
     raise SkipUpdate from None
 
 
+def _adopt_warmup_1(state, group, update, grad, param, exp_avg, exp_avg_sq):
+    utils.scale_by_exp_avg_sq_([exp_avg_sq], [update], 0, group["eps"])
+
+
+def _adopt_warmup_2(state, group, update, grad, param, exp_avg, exp_avg_sq):
+    u = utils.promote(update)
+    easq = utils.promote(exp_avg_sq)
+    utils.copy_stochastic_(exp_avg, u / easq.sqrt().clamp_(min=group["eps"]))
+    utils.scale_by_exp_avg_sq_(
+        [exp_avg_sq], [update], utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"]
+    )
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
+@warmup_guard(_adopt_warmup_1, _adopt_warmup_2)
 @no_state
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
-    if group["step"] == 1:
-        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-        raise SkipUpdate from None
-
-    if group["step"] == 2:
-        update = utils.promote(update)
-        easq = utils.promote(exp_avg_sq)
-        [
-            utils.copy_stochastic_(ea, u / easq_.sqrt().clamp_(min=group["eps"]))
-            for ea, u, easq_ in zip(exp_avg, update, easq)
-        ]
-        utils.scale_by_exp_avg_sq_(
-            exp_avg_sq,
-            update,
-            utils.beta_debias(utils.get_beta2(group), group["step"]),
-            group["eps"],
-        )
-        raise SkipUpdate from None
-
     utils.fused_adopt_(
         param,
         update,
@@ -697,14 +826,15 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     raise SkipUpdate from None
 
 
+def _suds_warmup_1(state, group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
+    utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
+
+
 @needs_full_param
 @zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
-@no_state_no_foreach
+@warmup_guard(_suds_warmup_1)
+@no_state_no_multi_tensor
 def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
-    if group["step"] == 1:
-        utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
-        raise SkipUpdate from None
-
     precond_update, w = utils.eigvecs_product_rank1(update.flatten(), fisher_approx.flatten().to(update.dtype))
     precond_update = utils.adam_(
         exp_avg,
@@ -737,27 +867,9 @@ def scale_by_unscaled_adam(group, update, grad, param, exp_avg, exp_avg_sq):
 
 @zero_guard("exp_avg", "exp_avg_sq")
+@warmup_guard(_adopt_warmup_1, _adopt_warmup_2)
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
-    if group["step"] == 1:
-        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
-        raise SkipUpdate from None
-
-    if group["step"] == 2:
-        update = utils.promote(update)
-        easq = utils.promote(exp_avg_sq)
-        [
-            utils.copy_stochastic_(ea, u / easq_.sqrt().clamp_(min=group["eps"]))
-            for ea, u, easq_ in zip(exp_avg, update, easq)
-        ]
-        utils.scale_by_exp_avg_sq_(
-            exp_avg_sq,
-            update,
-            utils.beta_debias(utils.get_beta2(group), group["step"]),
-            group["eps"],
-        )
-        raise SkipUpdate from None
-
     return utils.adopt(
         update,
         exp_avg_sq,
@@ -769,6 +881,7 @@ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
 
 def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+    tmp = utils.get_temporary(group, param) or {}
     Q = utils.init_Q_exprs(
         grad,
         group["precond_init_scale"],
         group["precond_init_scale_scale"],
         group["precond_init_scale_power"],
         group["max_size_triangular"],
         group["min_ndim_triangular"],
         group["memory_save_mode"],
-        getattr(param, "hessian_vector", None),
-        getattr(param, "vector", None),
+        tmp.get("hessian_vector"),
+        tmp.get("vector"),
         dtype=getattr(torch, group["q_dtype"]),
     )
     state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
     state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
-    state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)  # torch casts int to float in ckpt load
-    if group["adaptive"]:
-        state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
     if not cached:
         return
     state["Q_cache"] = [torch.empty_like(q) for q in Q]
 
 
+def _init_psgd_eigen_kron(state, group, update, grad, param, prob: Optional[callable] = None):
+    tmp = utils.get_temporary(group, param) or {}
+    Q = utils.init_Q_exprs(
+        grad,
+        group["precond_init_scale"],
+        group["precond_init_scale_scale"],
+        group["precond_init_scale_power"],
+        group["max_size_triangular"],
+        group["min_ndim_triangular"],
+        group["memory_save_mode"],
+        tmp.get("hessian_vector"),
+        tmp.get("vector"),
+        dtype=getattr(torch, group["q_dtype"]),
+    )
+    state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
+
+    line, group["store_triu_as_line"] = group["store_triu_as_line"], False
+    _update_psgd_precond(False, None, group, param, update, Q, state["running_lower_bound"], state["step"], prob)
+    group["store_triu_as_line"] = line
+    state["Q"] = utils.triu_to_line(Q) if line else Q
+    state["Q_basis"] = utils.init_psgd_eigenbasis(Q)
+
+
+def _init_psgd_pro_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+    Q = utils.init_Q_exprs(
+        grad,
+        group["precond_init_scale"],
+        group["precond_init_scale_scale"],
+        group["precond_init_scale_power"],
+        group["max_size_triangular"],
+        group["min_ndim_triangular"],
+        group["memory_save_mode"],
+        None,
+        None,
+        dtype=getattr(torch, group["q_dtype"]),
+    )
+    state["Q"] = Q
+    state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+    state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
+    if not cached:
+        return
+    state["Q_cache"] = [torch.empty_like(q) for q in Q]
+
+
 def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+    tmp = utils.get_temporary(group, param) or {}
     state["U"], state["V"], state["d"] = utils.init_lra(
         grad,
         group["param_count"],
@@ -800,30 +957,14 @@
         group["precond_init_scale_scale"],
         group["precond_init_scale_power"],
         group["rank"],
-        getattr(param, "hessian_vector", None),
-        getattr(param, "vector", None),
+        tmp.get("hessian_vector"),
+        tmp.get("vector"),
         dtype=getattr(torch, group["q_dtype"]),
     )
-
-def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
-    step = group["step"]
-    if "precondition_frequency" in group:
-        return step > 0 and step % group["precondition_frequency"] == 0
-    if isinstance(step, torch.Tensor):
-        utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
-        rng = random.Random(0x172381)
-    else:
-        rng = random.Random(0x172381 ^ step)
-    if "precond_scheduler" in group:
-        return utils.precond_schedule(step, group["precond_scheduler"], rng)
-    if prob is not None:
-        return utils.psgd_should_update(group, prob, rng, name=name)
-    raise ValueError("No preconditioner update schedule specified.")
-
-
 @needs_full_param
-@no_state_no_foreach
+@no_state_no_multi_tensor
 def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"):  # explore scale_mode="graft"
     if update.dim() < 2:
         return update
@@ -846,8 +987,20 @@ def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grok
     return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
 
 
+def _store_init_norm(state, group, update, grad, param):
+    state["init_norm"] = param.to(_storage_dtype(group)).norm()
+
+
+@needs_full_param
+@general_guard("init_norm", init_fn=_store_init_norm, skip_first=False)
+@no_state
+def update_by_hyperball(group, update, grad, param, init_norm):
+    utils.hyperball_step_(param, update, init_norm, group["lr"], group["weight_decay"], group["caution"], grad)
+    raise SkipUpdate from None
+
+
 def _store_std(state, group, update, grad, param):
-    state["init_std"] = torch.std(param)
+    state["init_std"] = torch.std(param.to(_storage_dtype(group)))
 
 
 @needs_full_param
@@ -1032,34 +1185,27 @@ def _update_psgd_precond(
     param,
     grad,
     Q,
-    velocity,
     running_lower_bound,
     step,
     prob: Optional[callable] = None,
-) -> Optional[Tensor]:
+) -> None:
     if prob is None:
         prob = utils.precond_update_prob_schedule()
     if not group["is_preconditioning"]:
         return
 
-    if utils.hasattr_none(param, "vector"):
-        vector, hessian_vector = param.vector, param.hessian_vector
-        del param.vector
-        del param.hessian_vector
-    elif group["inverse_free"]:
-        vector, hessian_vector = None, grad
-    else:
+    if (utils.get_temporary(group, param) or {}).get("vector") is None:
         vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
+    else:
+        vector, hessian_vector = utils.take_temporary(group, param, "vector", "hessian_vector")
 
-    precond = utils.psgd_update_precond(
+    utils.psgd_update_precond(
         hessian_vector,
         group["precond_lr"],
        Q,
        group["store_triu_as_line"],
-        velocity,
        utils.get_beta2(group),
-        group["ortho_method"],
        vector,
        running_lower_bound,
        group["lower_bound_beta"],
@@ -1073,12 +1219,10 @@ def _update_psgd_precond(
        float_prob = prob(group["step"])
    group["is_cached"] = should_use_cache = cached and float_prob < 0.5
 
-    if precond is not None:
-        return precond
    if not should_use_cache or not cached:
-        return None  # caching adds extra ops and is not worth the overhead when we precondition at every step
+        return
 
-    Q_resolved = utils.line_to_triu(Q, group["inverse_free"]) if group["store_triu_as_line"] else Q
+    Q_resolved = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
    for i, (c_, q_) in enumerate(zip(Q_cache, Q_resolved)):
        if c_ is None:
            c_ = (
@@ -1091,7 +1235,57 @@
            torch.matmul(q_.T, q_, out=c_)
        else:
            torch.mul(q_, q_, out=c_)
-    return None
+    return
+
+
+def _update_psgd_pro_precond(
+    cached,
+    Q_cache,
+    group,
+    param,
+    grad,
+    Q,
+    running_lower_bound,
+    step,
+    prob: Optional[callable] = None,
+) -> None:
+    if prob is None:
+        prob = utils.precond_update_prob_schedule()
+
+    if not group["is_preconditioning"]:
+        return
+
+    utils.psgd_pro_update_precond(
+        grad,
+        group["precond_lr"],
+        Q,
+        running_lower_bound,
+        group["lower_bound_beta"],
+        group["precond_update_power_iterations"],
+        group["dampening"],
+    )
+
+    if isinstance(prob, float):
+        float_prob = prob
+    else:
+        float_prob = prob(group["step"])
+    group["is_cached"] = should_use_cache = cached and float_prob < 0.5
+
+    if not should_use_cache or not cached:
+        return
+
+    for i, (c_, q_) in enumerate(zip(Q_cache, Q)):
+        if c_ is None:
+            c_ = (
+                torch.empty_like(q_)
+                if q_.ndim == 1
+                else torch.empty(q_.shape[0], q_.shape[0], device=q_.device, dtype=q_.dtype)
+            )
+            Q_cache[i] = c_
+        if q_.ndim == 2:
+            torch.matmul(q_.T, q_, out=c_)
+        else:
+            torch.mul(q_, q_, out=c_)
 
 
 def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
@@ -1099,9 +1293,7 @@
     if group.get("is_cached", False) and Q_cache[0] is not None:
         out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
     else:
-        out = utils.psgd_precond_grad(
-            preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
-        )
+        out = utils.psgd_precond_grad(preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs)
     group["caution"] = False  # we already cautioned here - shouldn't do it again
     return out
 
@@ -1118,9 +1310,7 @@ def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
     if group.get("is_cached", False) and Q_cache[0] is not None:
         utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
     else:
-        utils.fused_psgd_precond_grad(
-            preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
-        )
+        utils.fused_psgd_precond_grad(preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs)
 
 
 def _update_lra(
@@ -1129,12 +1319,10 @@
     if not group["is_preconditioning"]:
         return utils.multi_flatten((U, 1), (V, 1), (d, 0))
 
-    if utils.hasattr_none(params[0], "hessian_vector"):
-        vector = utils.flatten([p.vector for p in params])
-        hessian_vector = utils.flatten([p.hessian_vector for p in params])
-        for p in params:
-            del p.vector
-            del p.hessian_vector
+    if (utils.get_temporary(group, params[0]) or {}).get("hessian_vector") is not None:
+        vector_hv = [utils.take_temporary(group, p, "vector", "hessian_vector") for p in params]
+        vector = utils.flatten([v for v, _ in vector_hv])
+        hessian_vector = utils.flatten([hv for _, hv in vector_hv])
     else:
         vector, hessian_vector = utils.dampen_multiple(grads)
     precond_step = group["precond_step"] = group.get("precond_step", -1) + 1
@@ -1188,8 +1376,8 @@ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
 
 @needs_full_param
 @SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
 def scale_by_psgd(
     group,
     update,
@@ -1198,21 +1386,63 @@ def scale_by_psgd(
     update_to_precond,
     Q,
     Q_cache,
-    velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
     step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
     return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
 
 
 @needs_full_param
 @SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@zero_guard("exp_avg", "exp_avg_sq")
+@general_guard("Q", "Q_basis", "running_lower_bound", "step", init_fn=_init_psgd_eigen_kron, skip_first=True)
+@no_state_no_multi_tensor
+def scale_by_lather(
+    group,
+    update,
+    grad,
+    param,
+    update_to_precond,
+    exp_avg,
+    exp_avg_sq,
+    Q,
+    Q_basis,
+    running_lower_bound: List[Tensor],
+    step: Tensor,
+    prob: Optional[callable] = None,
+):
+    projected = utils.project(utils.promote(update), Q_basis, False)
+    precond = utils.adam_(
+        exp_avg,
+        exp_avg_sq,
+        projected,
+        utils.get_beta1(group),
+        utils.get_beta2(group),
+        group["step"],
+        group["eps"],
+    )[0]
+    precond = utils.project(precond, Q_basis, True)
+
+    if group["is_preconditioning"]:
+        _update_psgd_precond(False, None, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+        utils.update_psgd_eigenbasis(
+            utils.line_to_triu(Q) if group["store_triu_as_line"] else Q,
+            Q_basis,
+            exp_avg,
+        )
+
+    return precond
+
+
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
 def scale_by_delayed_psgd(
     group,
     update,
@@ -1221,27 +1451,21 @@ def scale_by_delayed_psgd(
     update_to_precond,
     Q,
     Q_cache,
-    velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
     step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-    if group.get("inverse_free", False):
-        precond = None
-    else:
-        precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
-    new = _update_psgd_precond(
-        cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob
-    )
-    return new if precond is None else precond
+    precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+    return precond
 
 
 @needs_full_param
 @SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
 def update_by_psgd(
     group,
     update,
@@ -1250,13 +1474,12 @@ def update_by_psgd(
     update_to_precond,
     Q,
     Q_cache,
-    velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
     step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
     raise SkipUpdate from None
 
@@ -1276,8 +1499,8 @@ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
 
 @needs_full_param
 @SqueezeGrad
 @PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
 def update_by_delayed_psgd(
     group,
     update,
@@ -1286,14 +1509,58 @@ def update_by_delayed_psgd(
     update_to_precond,
     Q,
     Q_cache,
-    velocity: Optional[List[Tensor]],
     running_lower_bound: List[Tensor],
     step: Tensor,
     cached: bool = False,
     prob: Optional[callable] = None,
 ):
     _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
-    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+    _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+    raise SkipUpdate from None
+
+
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
+@no_state_no_multi_tensor
+def scale_by_psgd_pro(
+    group,
+    update,
+    grad,
+    param,
+    update_to_precond,
+    Q,
+    Q_cache,
+    running_lower_bound: List[Tensor],
+    step: Tensor,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+    return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+
+
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
+@no_state_no_multi_tensor
+def update_by_psgd_pro(
+    group,
+    update,
+    grad,
+    param,
+    update_to_precond,
+    Q,
+    Q_cache,
+    running_lower_bound: List[Tensor],
+    step: Tensor,
+    cached: bool = False,
+    prob: Optional[callable] = None,
+):
+    _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+    _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
     raise SkipUpdate from None
 
@@ -1304,6 +1571,12 @@ def palm_beta2(state, group, update, grad, param):
 
 def apply_to_idx(fn, idx):
+    name = fn
+    if isinstance(fn, str):
+        fn = getattr(utils, fn, None)
+        if fn is None or not callable(fn):
+            raise ValueError(f"Unknown function '{name}'")
+
     def _fn(state, group, update, grad, param):
         args = [state, group, update, grad, param]
         return fn(args[idx])
@@ -1312,15 +1585,70 @@ def _fn(state, group, update, grad, param):
     return _fn
 
 
+_FSDP_HEADER_WIDTH = 4
+_FSDP_BUCKET_BYTES = 32 << 20
+_FSDP_DTYPE_CODES = {
+    torch.float64: 0,
+    torch.float32: 1,
+    torch.float16: 2,
+    torch.bfloat16: 3,
+    torch.int64: 4,
+    torch.int32: 5,
+    torch.int16: 6,
+    torch.int8: 7,
+    torch.uint8: 8,
+    torch.bool: 9,
+}
+
+
 class _ShapeInfo:
-    __slots__ = ("orig_shape", "offset", "total", "group", "owner")
+    __slots__ = ("orig_shape", "offset", "total", "group", "owner", "param_idx")
 
-    def __init__(self, orig_shape, offset=0, total=None, group=None, owner=None):
+    def __init__(self, orig_shape, offset=0, total=None, group=None, owner=None, param_idx=None):
         self.orig_shape = orig_shape
         self.offset = offset
         self.total = total if total is not None else math.prod(orig_shape)
         self.group = group
         self.owner = owner
+        self.param_idx = param_idx
+
+
+class _FSDPBucket:
+    __slots__ = ("device", "dtype", "send_entries", "send_splits", "recv_entries", "recv_splits")
+
+    def __init__(self, device, dtype, send_entries, send_splits, recv_entries, recv_splits):
+        self.device = device
+        self.dtype = dtype
+        self.send_entries = send_entries
+        self.send_splits = send_splits
+        self.recv_entries = recv_entries
+        self.recv_splits = recv_splits
+
+
+class _FSDPState:
+    __slots__ = ("items", "buckets")
+
+    def __init__(self, items, buckets):
+        self.items = items
+        self.buckets = buckets
+
+
+def _dtype_code(dtype):
+    if dtype not in _FSDP_DTYPE_CODES:
+        raise TypeError(f"Unsupported FSDP shard dtype: {dtype}")
+    return _FSDP_DTYPE_CODES[dtype]
+
+
+def _assign_fsdp_owners(entries, shard_sizes, world_size):
+    loads = [0] * world_size
+    owners = []
+    for i, (p, _, total, _) in enumerate(entries):
+        active = shard_sizes[i].nonzero().squeeze(-1).tolist()
+        candidates = active or list(range(world_size))
+        owner = min(candidates, key=loads.__getitem__)
+        loads[owner] += total * p.element_size()
+        owners.append(owner)
+    return owners
 
 
 def _detect_orig_shapes(params):
@@ -1343,8 +1671,6 @@
         for param, spi, shape in zip(obj._params, obj._shard_param_infos, obj._shapes):
             lookup[id(param)] = (tuple(shape), spi)
 
-    _dist = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
-
     # optimizer param order is stable across ranks
     fsdp_entries = [
         (p, s, math.prod(s), spi)
         for p in params
         for s, spi in [lookup.get(id(p), (None, None))]
         if id(p) in fsdp_ids and s is not None
     ]
-
-    groups = {}
-    if _dist and fsdp_entries:
-        rank = torch.distributed.get_rank()
-        ws = torch.distributed.get_world_size()
-        n = len(fsdp_entries)
-        flags = torch.zeros(n, ws, dtype=torch.int32, device=fsdp_entries[0][0].device)
-        for i, (p, orig, total, spi) in enumerate(fsdp_entries):
-            if spi.in_shard and spi.numel_in_shard is not None and spi.numel_in_shard < total:
-                flags[i, rank] = 1
-        torch.distributed.all_reduce(flags)
-        pg_cache = {}
-        for i in range(n):
-            ranks = flags[i].nonzero().squeeze(-1).tolist()
-            if len(ranks) >= 2:
-                key = tuple(ranks)
-                if key not in pg_cache:
-                    pg_cache[key] = torch.distributed.new_group(ranks)
-                groups[i] = (pg_cache[key], ranks)
-
-    # owner must be precomputed (different ranks see different subsets in _reshape_params)
     result = {}
-    split_idx = 0
-    for i, (p, orig, total, spi) in enumerate(fsdp_entries):
-        sg = groups.get(i)
-        if sg:
-            pg, ranks = sg
-            owner = ranks[split_idx % len(ranks)]
-            split_idx += 1
-            if spi.in_shard:
-                result[id(p)] = _ShapeInfo(orig, spi.intra_param_start_idx, total, pg, owner)
-        elif spi.in_shard:
-            result[id(p)] = _ShapeInfo(orig, spi.intra_param_start_idx, total)
+    ws = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+    if ws > 1 and fsdp_entries:
+        rank = torch.distributed.get_rank()
+        shard_sizes = torch.zeros(len(fsdp_entries), ws, dtype=torch.int64, device=fsdp_entries[0][0].device)
+        for i, (p, _, _, spi) in enumerate(fsdp_entries):
+            shard_sizes[i, rank] = p.numel() if spi.in_shard else 0
+        torch.distributed.all_reduce(shard_sizes)
+        owners = _assign_fsdp_owners(fsdp_entries, shard_sizes, ws)
+    else:
+        owners = [None] * len(fsdp_entries)
+    for param_idx, ((p, orig, total, spi), owner) in enumerate(zip(fsdp_entries, owners)):
+        offset = 0 if spi.intra_param_start_idx is None else spi.intra_param_start_idx
+        result[id(p)] = _ShapeInfo(orig, offset, total, owner=owner, param_idx=param_idx)
     return result
 
 
-def _reduce_gather(shard, offset, total, pg, dst):
-    full = shard.new_zeros(total)
-    full[offset : offset + shard.numel()] = shard
-    torch.distributed.reduce(full, dst=dst, group=pg)
-    return full
+def _exchange_split_sizes(splits, device):
+    send = torch.tensor(splits, dtype=torch.int64, device=device)
+    recv = torch.empty_like(send)
+    torch.distributed.all_to_all_single(recv, send)
+    return recv.tolist()
+
+
+def _all_to_all_variable(sendbuf, recv_splits, send_splits):
+    recv = sendbuf.new_empty(sum(recv_splits))
+    torch.distributed.all_to_all_single(recv, sendbuf, output_split_sizes=recv_splits, input_split_sizes=send_splits)
+    return recv
+
+
+def _fsdp_bucket_schedule(items):
+    buckets, current, lookup = [], {}, {}
+    for p, info, _ in items:
+        key = (p.device, p.dtype)
+        idx = current.get(key)
+        size = info.total * p.element_size()
+        if idx is None or (buckets[idx][2] and buckets[idx][2] + size > _FSDP_BUCKET_BYTES):
+            idx = len(buckets)
+            buckets.append([p.device, p.dtype, 0])
+            current[key] = idx
+        buckets[idx][2] += size
+        lookup[info.param_idx] = idx
+    return [(device, dtype) for device, dtype, _ in buckets], lookup
+
+
+def _exchange_fsdp_shards(schedule, bucket_lookup, items, tensor_getter, keep_state=False):
+    ws = torch.distributed.get_world_size()
+    per_bucket = [[] for _ in schedule]
+    for p, info, shard in items:
+        tensor = tensor_getter(p, shard)
+        if tensor is None or tensor.numel() == 0:
+            continue
+        flat = tensor.reshape(-1)
+        bucket_idx = bucket_lookup[info.param_idx]
+        device, dtype = schedule[bucket_idx]
+        if flat.device != device or flat.dtype != dtype:
+            raise RuntimeError(
+                f"FSDP bucket mismatch for param {info.param_idx}: expected {(device, dtype)}, got {(flat.device, flat.dtype)}"
+            )
+        per_bucket[bucket_idx].append((info.owner, info.param_idx, info.offset, flat, shard))
+
+    received, states = {}, []
+    for (device, dtype), bucket_entries in zip(schedule, per_bucket):
+        by_dst = [[] for _ in range(ws)]
+        for entry in bucket_entries:
+            by_dst[entry[0]].append(entry)
+
+        send_meta_splits = [len(dst_entries) * _FSDP_HEADER_WIDTH for dst_entries in by_dst]
+        send_payload_splits = [sum(flat.numel() for _, _, _, flat, _ in dst_entries) for dst_entries in by_dst]
+        recv_meta_splits = _exchange_split_sizes(send_meta_splits, device)
+        recv_payload_splits = _exchange_split_sizes(send_payload_splits, device)
+
+        code = _dtype_code(dtype)
+        meta = [
+            value
+            for dst_entries in by_dst
+            for _, param_idx, offset, flat, _ in dst_entries
+            for value in (param_idx, offset, flat.numel(), code)
+        ]
+        payload = [flat for dst_entries in by_dst for _, _, _, flat, _ in dst_entries]
+        send_meta = (
+            torch.tensor(meta, dtype=torch.int64, device=device)
+            if meta
+            else torch.empty(0, dtype=torch.int64, device=device)
+        )
+        send_payload = torch.cat(payload) if payload else torch.empty(0, dtype=dtype, device=device)
+
+        recv_meta = _all_to_all_variable(send_meta, recv_meta_splits, send_meta_splits)
+        recv_entries = [[] for _ in range(ws)]
+        meta_offset = 0
+        for src, count in enumerate(recv_meta_splits):
+            if count == 0:
+                continue
+            if count % _FSDP_HEADER_WIDTH:
+                raise RuntimeError(f"Malformed FSDP metadata split: {count}")
+            rows = recv_meta[meta_offset : meta_offset + count].view(-1, _FSDP_HEADER_WIDTH).cpu().tolist()
+            meta_offset += count
+            for param_idx, offset, length, got in rows:
+                if got != code:
+                    raise RuntimeError(f"FSDP dtype mismatch for bucket {dtype}: expected {code}, got {got}")
+                recv_entries[src].append((param_idx, offset, length))
+
+        recv_payload = _all_to_all_variable(send_payload, recv_payload_splits, send_payload_splits)
+        payload_offset = 0
+        for src_entries in recv_entries:
+            for param_idx, offset, length in src_entries:
+                chunk = recv_payload[payload_offset : payload_offset + length]
+                received.setdefault(param_idx, []).append((offset, chunk))
+                payload_offset += length
+        if payload_offset != recv_payload.numel():
+            raise RuntimeError("FSDP payload unpack mismatch")
+
+        if keep_state:
+            states.append(_FSDPBucket(device, dtype, by_dst, send_payload_splits, recv_entries, recv_payload_splits))
+
+    return received, states
+
+
+def _reshape_fsdp_params(items):
+    rank = torch.distributed.get_rank()
+    schedule, bucket_lookup = _fsdp_bucket_schedule(items)
+    params, buckets = _exchange_fsdp_shards(schedule, bucket_lookup, items, lambda _, shard: shard, keep_state=True)
+    grads, _ = _exchange_fsdp_shards(schedule, bucket_lookup, items, lambda p, _: p.grad)
+
+    for p, info, shard in items:
+        p.grad = None
+        if info.owner != rank:
+            continue
+
+        pieces = params.get(info.param_idx, ())
+        total = sum(chunk.numel() for _, chunk in pieces)
+        if total != info.total:
+            raise RuntimeError(f"FSDP parameter assembly mismatch for param {info.param_idx}: {total} != {info.total}")
+
+        full = shard.new_empty(info.total)
+        for offset, chunk in pieces:
+            full[offset : offset + chunk.numel()].copy_(chunk)
+        p.data = full.view(info.orig_shape)
+
+        grad_pieces = grads.get(info.param_idx, ())
+        if not grad_pieces:
+            continue
+        grad_total = sum(chunk.numel() for _, chunk in grad_pieces)
+        if grad_total != info.total:
+            raise RuntimeError(f"FSDP grad assembly mismatch for param {info.param_idx}: {grad_total} != {info.total}")
+        grad = full.new_empty(info.total, dtype=grad_pieces[0][1].dtype)
+        for offset, chunk in grad_pieces:
+            grad[offset : offset + chunk.numel()].copy_(chunk)
+        p.grad = grad.view(info.orig_shape)
+
+    return _FSDPState(items, buckets)
+
+
+def _restore_fsdp_params(state):
+    by_param = {info.param_idx: (p, info, shard) for p, info, shard in state.items}
+    for bucket in state.buckets:
+        payload = []
+        for dst, recv_entries in enumerate(bucket.recv_entries):
+            for param_idx, offset, length in recv_entries:
+                p, info, _ = by_param[param_idx]
+                flat = p.data.reshape(-1)
+                if flat.numel() != info.total:
+                    raise RuntimeError(f"FSDP return path expects full param {param_idx}, got {flat.numel()}")
+                payload.append(flat[offset : offset + length])
+        send_payload = torch.cat(payload) if payload else torch.empty(0, dtype=bucket.dtype, device=bucket.device)
+        recv_payload = _all_to_all_variable(send_payload, bucket.send_splits, bucket.recv_splits)
+
+        payload_offset = 0
+        for send_entries in bucket.send_entries:
+            for _, _, _, flat, shard in send_entries:
+                shard.copy_(recv_payload[payload_offset : payload_offset + flat.numel()])
+                payload_offset += flat.numel()
+        if payload_offset != recv_payload.numel():
+            raise RuntimeError("FSDP return payload unpack mismatch")
+
+    for p, _, shard in state.items:
+        p.data = shard
+        p.grad = None
 
 
 def _view_param(p, shape):
@@ -1404,57 +1867,55 @@ def _reshape_params(params, orig_shapes, gather=True):
     if not orig_shapes:
         return [], []
-    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+    dist_ready = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
     views, gathers = [], []
     for p in params:
         info = orig_shapes.get(id(p))
-        if info is None or p.data.shape == info.orig_shape:
+        if info is None:
             continue
-        if info.group is not None and gather:
+        if gather and dist_ready and info.owner is not None:
             shard = p.data
-            full = _reduce_gather(shard, info.offset, info.total, info.group, info.owner)
-            if rank == info.owner:
-                p.data = full.view(info.orig_shape)
-                if p.grad is not None:
-                    p.grad = _reduce_gather(p.grad, info.offset, info.total, info.group, info.owner).view(
-                        info.orig_shape
-                    )
-            else:
-                del full
-                if p.grad is not None:
-                    _reduce_gather(p.grad, info.offset, info.total, info.group, info.owner)
-                    p.grad = None
             gathers.append((p, info, shard))
+            continue
+
+        if p.data.shape == info.orig_shape:
+            continue
+
+        orig, numel = info.orig_shape, p.data.numel()
+        if numel == info.total:
+            target = orig
+        elif numel > 0 and len(orig) >= 2:
+            inner = math.prod(orig[1:])
+            target = (numel // inner, *orig[1:]) if numel % inner == 0 else None
         else:
-            orig, numel = info.orig_shape, p.data.numel()
-            if numel == info.total:
-                target = orig
-            elif numel > 0 and len(orig) >= 2:
-                inner = math.prod(orig[1:])
-                target = (numel // inner, *orig[1:]) if numel % inner == 0 else None
-            else:
-                continue
-            if target is not None:
-                flat = p.data.shape
-                _view_param(p, target)
-                views.append((p, flat))
+            continue
+        if target is not None:
+            flat = p.data.shape
+            _view_param(p, target)
+            views.append((p, flat))
+
+    if gathers:
+        gathers = _reshape_fsdp_params(gathers)
     return views, gathers
 
 
 def _restore_params(views, gathers):
-    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
-    for p, info, shard in gathers:
-        if rank == info.owner:
-            full = p.data.flatten()
-        else:
-            full = shard.new_empty(info.total)
-        torch.distributed.broadcast(full, src=info.owner, group=info.group)
-        shard.copy_(full[info.offset : info.offset + shard.numel()])
-        p.data = shard
-        p.grad = None
+    if isinstance(gathers, _FSDPState):
+        _restore_fsdp_params(gathers)
+    else:
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        for p, info, shard in gathers:
+            if rank == info.owner:
+                full = p.data.flatten()
+            else:
+                full = shard.new_empty(info.total)
+            torch.distributed.broadcast(full, src=info.owner, group=info.group)
+            shard.copy_(full[info.offset : info.offset + shard.numel()])
            p.data = shard
            p.grad = None
     for p, flat in views:
         _view_param(p, flat)
@@ -1472,7 +1933,7 @@ def _inner_chain(state, group, update, grad, param, *fns):
     return update, skip_update
 
 
-def chain(state: Union[callable, dict], group, grad, param, *fns):
+def chain(state: list, group, grad, param, 
*fns): update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad] ecc = ECCConfig.from_group(group, key="param_ecc") @@ -1483,8 +1944,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns): utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad) return - states = [state(p) for p in param] - corrs = [st["param::ecc"] for st in states] + corrs = [st["param::ecc"] for st in state] with ecc.attached(param, corrs): update, skip_update = _inner_chain(state, group, update, grad, param, *fns) if not skip_update and update is not None: @@ -1500,9 +1960,14 @@ def _walk_fns(obj): stack.append(cur.fn) elif isinstance(cur, functools.partial): stack.append(cur.func) - elif isinstance(cur, Branch): + elif isinstance(cur, Parallel): for branch in cur.branches: stack.extend(branch) + elif isinstance(cur, Route): + for _, fns in cur.routes: + stack.extend(fns) + if cur.default is not None: + stack.extend(cur.default) elif isinstance(cur, _Iterable) and not isinstance(cur, (str, bytes, bytearray)): stack.extend(cur) @@ -1518,6 +1983,7 @@ def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0): for ft in _walk_fns(new_fns): if not retain or ft.transform_idx is None: ft.transform_idx, offset = offset, offset + 1 + ft._build_val_names() return new_fns @@ -1532,15 +1998,18 @@ class ChainOpt(utils.StatefulOptimizer): "eps": 1e-8, } - def __init__(self, params, defaults, foreach: bool, *fns): + def __init__(self, params, defaults, *fns): orig = defaults.pop("orig_shapes", None) self._orig_shapes = ( {k: _ShapeInfo(v) if isinstance(v, tuple) else v for k, v in orig.items()} if orig is not None else None ) base = self.global_defaults.copy() base.update({k: v for k, v in defaults.items() if v is not use_default}) - super().__init__(params, base, foreach) + super().__init__(params, base) self.fns = fns + self._eager_chain = self._run_chain + if self.compile_step: + self._run_chain = 
torch.compile(self._run_chain, fullgraph=True) self.register_load_state_dict_post_hook(ChainOpt._restore_ecc_dtypes) self._init_param_ecc() @@ -1631,17 +2100,28 @@ def fns(self, value): self._fns = value self._set_indices(retain=True) self._needs_gather = any(getattr(ft, "needs_full_param", False) for ft in _walk_fns(self._fns)) + self._transform_ids = frozenset( + ft.transform_idx + for ft in _walk_fns(self._fns) + if ft.transform_idx is not None and getattr(ft, "needs_init", True) + ) def _set_indices(self, retain=True): self._fns = set_indices(self.fns, retain) + def _find_val_name(self, name): + for ft in _walk_fns(self._fns): + if name in ft._val_names: + return ft._val_names[name] + raise KeyError(f"No transform stores '{name}'") + def _step(self, group): if "base_lr" not in group: group["base_lr"] = group["lr"] - if "prev_lr" in group and group["prev_lr"] != group["lr"]: + if "base_lr" in group and group["base_lr"] != group["lr"]: utils.warn_once( f"Learning rate changed between steps. This is an experimental feature and " - f"only supported with foreach=True (currently foreach={group['foreach']})." + f"only supported with multi_tensor=True (currently multi_tensor={group['multi_tensor']})." 
) group["base_lr"] = group["lr"] @@ -1664,48 +2144,69 @@ def _step_inner(self, group): return p, g = zip(*vals) - for param in p: - state = self.state_(param) - if "step" in state: - step = state["step"] - elif self.compile_step: - step = utils.scalar_guard(0, param) + step = group.get("_group_step") + if step is None: + for param in group["params"]: + param_state = self.state.get(param) + if not isinstance(param_state, dict): + continue + for idx_state in param_state.values(): + if isinstance(idx_state, dict) and "step" in idx_state: + step = idx_state["step"] + break + if step is not None: + break else: step = 0 - break - - group["step"] = state["step"] = step = step + 1 - group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1) + if isinstance(step, torch.Tensor): + step = step.to(device=p[0].device, dtype=torch.int64) + else: + step = utils.scalar_guard(step, p[0]) + group["_group_step"] = group["step"] = step = step + 1 + self.state_(p[0])["step"] = step + group["prev_lr"] = group["lr"] = group["base_lr"] * step / step.clamp(min=group["warmup_steps"] + 1) - if not group["foreach"] or len(p) == 1: + if not group["multi_tensor"] or len(p) == 1: for param, grad in zip(p, g): - chain(self.state_, group, [grad], [param], *self.fns) - group["caution"] = caution + self._chain(group, [grad], [param], caution) else: - chain(self.state_, group, g, p, *self.fns) + self._chain(group, g, p, caution) group["caution"] = caution - group["lr"] = group["prev_lr"] + group["lr"] = group["base_lr"] group["step"] = None + def _run_chain(self, state, group, g, p, caution): + chain(state, group, g, p, *self.fns) + group["caution"] = caution -str_or_fn = Union[str, callable, None, Literal[use_default]] + def _needs_init(self, state): + ids = self._transform_ids + if not ids: + return False + all_initialized = set() + for st in state: + all_initialized.update(st.get("is_initialized", ())) + return not ids.issubset(all_initialized) + + def 
_needs_eager(self, group, state): + if self._needs_init(state): + return True + if group.get("is_preconditioning", False): + return True + if group.get("ecc") or group.get("param_ecc"): + return True + return False + + def _chain(self, group, g, p, caution): + state = [self.state_(pi) for pi in p] + fn = self._run_chain + if self.compile_step and self._needs_eager(group, state): + fn = self._eager_chain + fn(state, group, g, p, caution) -def _get_clip_fn(name: str_or_fn, default_val: str_or_fn): - name = default(name, default_val) - if callable(name): - return name - elif name not in ( - "l2_clip_", - "rmsnorm_clip_", - "trust_region_clip_", - "a_law_compress", - "mu_law_compress", - "softsign_compress", - ): - raise ValueError(f"Clipping function {name} not found") - return getattr(utils, name) +str_or_fn = Union[str, callable, None, Literal[use_default]] def default(a, b): @@ -1723,6 +2224,7 @@ def default(a, b): scale_by_laprop.get_fn(): update_by_laprop, # scale_by_adopt.get_fn(): update_by_adopt, # scale_by_ademamix.get_fn(): update_by_ademamix, # + scale_by_psgd_pro.get_fn(): update_by_psgd_pro, # } _scale_to_update_map_inv = { update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, # @@ -1734,6 +2236,7 @@ def default(a, b): update_by_laprop.get_fn(): scale_by_laprop, # update_by_adopt.get_fn(): scale_by_adopt, # update_by_ademamix.get_fn(): scale_by_ademamix, # + update_by_psgd_pro.get_fn(): scale_by_psgd_pro, # } @@ -1742,23 +2245,17 @@ class BaseOpt(ChainOpt): Base Optimizer compile_step: bool = False - Whether to change some internals to try to make the optimizer compilable - This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP) + Whether to torch.compile the optimizer step (fullgraph=True). + Initialization runs eagerly on the first step; subsequent steps are compiled. 
promote: bool = False - Whether to promote the gradients to fp32 before applying the optimizer - Improves update quality for low-precision parameters, but increases costs - Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime + Whether to promote the gradients to fp32 before applying the optimizer. gradient_clipping: str_or_fn = None - The function to use for clipping the incoming gradients, before any other transformations. - This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain. + Clipping function applied to incoming gradients before any other transforms. update_clipping: str_or_fn = None - The function to use for clipping the outgoing updates before applying them, after all other transformations. - This will turn off fused updates. - This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain. - + Clipping function applied to outgoing updates. Disables fused updates. 
""" gradient_clipping: str_or_fn = None @@ -1770,7 +2267,6 @@ def __init__( self, params, defaults, - foreach: bool = True, gradient_clipping: str_or_fn = None, update_clipping: str_or_fn = None, palm: bool = use_default, @@ -1818,57 +2314,49 @@ def __init__( if default(update_clipping, self.update_clipping) is not None: fns = fns + (apply_to_idx(update_clipping, 2),) - super().__init__(params, defaults, foreach, *fns) + super().__init__(params, defaults, *fns) class ScheduleFree(BaseOpt): def eval(self): + return self.train(False) + + def train(self, mode: bool = True): + z_key = self._find_val_name("z") for group in self.param_groups: - group["train_mode"] = train_mode = not group.get("train_mode") - beta1 = utils.get_beta1(group) - if beta1 > 0 and not train_mode: - for p in group["params"]: - state = self.state_(p) - if "z" in state: - # Set p.data to x - z = utils.promote(state["z"]) - p32 = utils.promote(p.data) - p32.lerp_(end=z, weight=1 - 1 / beta1) - utils.copy_stochastic_(p.data, p32) - - def train(self): - for group in self.param_groups: - group["train_mode"] = train_mode = not group.get("train_mode") + train_mode = group.get("train_mode", True) + if train_mode == mode: + continue + group["train_mode"] = mode beta1 = utils.get_beta1(group) - if beta1 > 0 and train_mode: - for p in group["params"]: - state = self.state_(p) - if "z" in state: - z = utils.promote(state["z"]) - p32 = utils.promote(p.data) - p32.lerp_(end=z, weight=1 - beta1) - utils.copy_stochastic_(p.data, p32) + if beta1 <= 0: + continue + weight = 1 - beta1 if mode else 1 - 1 / beta1 + for p in group["params"]: + state = self.state_(p) + if z_key in state: + z = utils.promote(state[z_key]) + p32 = utils.promote(p.data) + p32.lerp_(end=z, weight=weight) + utils.copy_stochastic_(p.data, p32) + return self class MSAM(BaseOpt): def eval(self): + return self.train(False) + + def train(self, mode: bool = True): + z_key = self._find_val_name("z") for group in self.param_groups: - 
group["train_mode"] = train_mode = not group.get("train_mode") - if not train_mode: - for p in group["params"]: - state = self.state_(p) - if "z" in state: - p_copy = p.data.clone() - utils.copy_stochastic_(p.data, state["z"]) - utils.copy_stochastic_(state["z"], p_copy) - - def train(self): - for group in self.param_groups: - group["train_mode"] = train_mode = not group.get("train_mode") - if train_mode: - for p in group["params"]: - state = self.state_(p) - if "z" in state: - p_copy = p.data.clone() - utils.copy_stochastic_(p.data, state["z"]) - utils.copy_stochastic_(state["z"], p_copy) + train_mode = group.get("train_mode", True) + if train_mode == mode: + continue + group["train_mode"] = mode + for p in group["params"]: + state = self.state_(p) + if z_key in state: + p_copy = p.data.clone() + utils.copy_stochastic_(p.data, state[z_key]) + utils.copy_stochastic_(state[z_key], p_copy) + return self diff --git a/heavyball/helpers.py b/heavyball/helpers.py index 23a4e77..71a2cdb 100644 --- a/heavyball/helpers.py +++ b/heavyball/helpers.py @@ -178,7 +178,6 @@ class BoTorchSampler(SimpleAPIBaseSampler): """ A significantly more efficient implementation of `BoTorchSampler` from Optuna - keeps more on the GPU / in torch The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License - The original API is kept for backward compatibility, but many arguments are ignored to improve maintainability. 
""" def __init__( @@ -186,18 +185,12 @@ def __init__( search_space: Optional[dict[str, BaseDistribution]] = None, *, candidates_func: Optional[Callable[..., Tensor]] = None, - constraints_func: Optional[Callable[..., Tensor]] = None, n_startup_trials: int = 10, - consider_running_trials: bool = False, independent_sampler: Optional[BaseSampler] = None, seed: int | None = None, device: torch.device | str | None = None, trial_chunks: int = 128, ): - if constraints_func is not None: - raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.") - if consider_running_trials: - raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.") if candidates_func is not None and not callable(candidates_func): raise TypeError("candidates_func must be callable.") self._candidates_func = candidates_func @@ -206,7 +199,6 @@ def __init__( self._seed = seed self.trial_chunks = trial_chunks - self._study_id: int | None = None self.search_space = {} if search_space is None else dict(search_space) if isinstance(device, str): device = torch.device(device) @@ -643,12 +635,9 @@ def __init__( search_space: dict[str, BaseDistribution], *, seed: int | None = None, - constant_liar: bool = False, independent_sampler: BaseSampler | None = None, ) -> None: super().__init__(search_space, seed) - if constant_liar: - raise NotImplementedError("constant_liar is not supported by HEBOSampler.") self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed) self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed) self._rng = np.random.default_rng(seed) @@ -716,7 +705,6 @@ def __init__( upper: np.ndarray, seed: Optional[int] = None, population_size: Optional[int] = None, - learning_rate: Optional[float] = None, last_n: int = 4096, loco_step_size: float = 0.1, device: str | None = None, @@ -733,7 +721,6 @@ def __init__( self.last_n = last_n self.batchnorm_decay = 
batchnorm_decay self.score_decay = score_decay - self._learning_rate = learning_rate or 1.0 / np.sqrt(n_dimension) self._mean = torch.from_numpy(mean).to(device) self._sigma = torch.from_numpy(inv_sigma).to(device) self._lower = torch.from_numpy(lower).to(device) @@ -832,20 +819,16 @@ def __init__( search_space: Dict[str, BaseDistribution], x0: Optional[Dict[str, Any]] = None, sigma0: Optional[float] = None, - lr: Optional[float] = None, n_startup_trials: int = 1, independent_sampler: Optional[BaseSampler] = None, - warn_independent_sampling: bool = True, seed: Optional[int] = None, population_size: Optional[int] = None, ) -> None: self.search_space = search_space self._x0 = x0 self._sigma0 = sigma0 - self._lr = lr self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed) self._n_startup_trials = n_startup_trials - self._warn_independent_sampling = warn_independent_sampling self._optimizer: Optional[FastINGO] = None self._seed = seed self._population_size = population_size @@ -906,7 +889,6 @@ def sample_relative( return {} if len(search_space) == 1: - self._warn_independent_sampling = False return {} trans = SearchSpaceTransform(search_space) @@ -949,7 +931,6 @@ def _init_optimizer( upper=upper_bounds, seed=self._seed, population_size=population_size, - learning_rate=self._lr, ) def sample_independent( @@ -1022,10 +1003,7 @@ def __init__( search_space: Optional[dict[str, BaseDistribution]] = None, *, seed: int | None = None, - constraints_func: Optional[Callable[..., Any]] = None, ) -> None: - if constraints_func is not None: - raise NotImplementedError("constraints_func is not supported by AutoSampler.") if samplers is None: if search_space is None: raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.") @@ -1036,7 +1014,6 @@ def __init__( self._rng = LazyRandomState(seed) self._random_sampler = RandomSampler(seed=seed) self._thread_local_sampler = ThreadLocalSampler() - 
self._constraints_func = constraints_func self._completed_trials = 0 self._current_index = -1 diff --git a/heavyball/utils.py b/heavyball/utils.py index f935fca..0101df7 100644 --- a/heavyball/utils.py +++ b/heavyball/utils.py @@ -7,11 +7,10 @@ import itertools import math import pickle -import random import re import string import warnings -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -107,8 +106,19 @@ def _fn(*args, **kwargs): return _fn +def decorator_no_fullgraph(func: Callable): + return decorator_knowngood(func, fullgraph=False) + + einsum_base = string.ascii_lowercase +no_compile_qr = torch.compiler.disable(torch.linalg.qr) +no_compile_eigh = torch.compiler.disable(torch.linalg.eigh) +no_compile_solve = torch.compiler.disable(torch.linalg.solve) +no_compile_svd = torch.compiler.disable(torch.linalg.svd) +no_compile_lobpcg = torch.compiler.disable(torch.lobpcg) +no_compile_solve_triangular = torch.compiler.disable(torch.linalg.solve_triangular) + @decorator_knowngood def compiled_einsum(expr, *args): @@ -657,7 +667,7 @@ def _scion_bias_rms_direction(x: Tensor, eps: float = 1e-8) -> Tensor: def _scion_spectral_direction(x: Tensor) -> Tensor: flat = x.reshape(x.shape[0], -1) - inplace_orthogonal_(flat) + flat = inplace_orthogonal_(flat) normalized = flat.reshape_as(x) in_dim = max(flat.shape[1], 1) scale = math.sqrt(x.shape[0] / in_dim) @@ -666,7 +676,7 @@ def _scion_spectral_direction(x: Tensor) -> Tensor: def _scion_spectral_conv_direction(x: Tensor) -> Tensor: flat = x.reshape(x.shape[0], -1) - inplace_orthogonal_(flat) + flat = inplace_orthogonal_(flat) normalized = flat.reshape_as(x) out_channels, in_channels = x.shape[:2] spatial = math.prod(x.shape[2:]) if x.ndim > 2 else 1 @@ -771,11 +781,11 @@ def _compilable_grafting(magnitude, direction): return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6)) 
-@decorator_knowngood
+@decorator_no_fullgraph
 def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
     if not isinstance(mode, ZerothPowerMode):
         mode = ZerothPowerMode(mode)
-    if not isinstance(scale_mode, ZerothPowerMode):
+    if not isinstance(scale_mode, OrthoScaleMode):
         scale_mode = OrthoScaleMode(scale_mode)
     if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
@@ -784,12 +794,12 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
     elif mode == ZerothPowerMode.legacy_newtonschulz:
         y = legacy_zeropower_via_newtonschulz5(x, 5)
     elif mode == ZerothPowerMode.qr:
-        y = torch.linalg.qr(promote(x)).Q
+        y = no_compile_qr(promote(x)).Q
     elif mode == ZerothPowerMode.svd:
-        u, _s, vt = torch.linalg.svd(promote(x))
+        u, _s, vt = no_compile_svd(promote(x))
         y = u @ vt
     elif mode == ZerothPowerMode.legacy_svd:
-        u, _s, vt = torch.linalg.svd(promote(x))
+        u, _s, vt = no_compile_svd(promote(x))
         y = u @ vt.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
@@ -816,7 +826,7 @@ def _compilable_scatter_set(target, source, index):
     target[:] = source.contiguous()[index].reshape_as(target)
 
 
-# @decorator_knowngood
+@decorator_no_fullgraph
 def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
@@ -885,6 +895,126 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor
         copy_stochastic_(q, q_new)
 
 
+def _transform_projected_state(old_qs: List[Optional[Tensor]], new_qs: List[Optional[Tensor]], *states: Tensor):
+    if not states:
+        return
+
+    ref = states[0]
+    if ref is None or ref.dim() == 0:
+        return
+
+    assert ref.ndim < 13, "ref.ndim must be less than 13"
+    in_str = einsum_base[: ref.dim()]
+    out_str = einsum_base[ref.dim() : 2 * ref.dim()]
+
+    old_basis = ",".join([o + i for q, i, o in zip(old_qs, in_str, in_str.upper()) if q is not None])
+    if not old_basis:
+        return
+
+    new_basis = ",".join([i + o for q, i, o in zip(new_qs, in_str.upper(), out_str) if q is not None])
+    out_str = "".join([o if o in new_basis else i for i, o in zip(in_str, out_str)])
+    subscripts = f"{in_str},{old_basis},{new_basis}->{out_str}"
+    old_basis = [promote(q) for q in old_qs if q is not None]
+    new_basis = [promote(q) for q in new_qs if q is not None]
+
+    for state in states:
+        new = compiled_einsum(subscripts, promote(state), *old_basis, *new_basis)
+        copy_stochastic_(state, new)
+
+
+@decorator_no_fullgraph
+def init_psgd_eigenbasis(Q: List[Tensor]):
+    out = []
+
+    for q in Q:
+        if q.ndim < 2:
+            out.append(None)
+            continue
+
+        q32 = promote(q)
+        out.append(_stable_symmetric_basis(q32.mT @ q32, out_device=q.device, out_dtype=q.dtype))
+
+    return out
+
+
+@decorator_no_fullgraph
+def get_psgd_eigenbasis(Q: List[Tensor], prev: List[Optional[Tensor]]):
+    out = []
+
+    for q, old_basis in zip(Q, prev):
+        if q.ndim < 2:
+            out.append(None)
+            continue
+        if old_basis is None:
+            raise ValueError(
+                "get_psgd_eigenbasis requires a previous basis for matrix blocks; use init_psgd_eigenbasis"
+            )
+
+        q32 = promote(q)
+        old_basis32 = promote(old_basis)
+        Y = q32.mT @ (q32 @ old_basis32)
+        basis_raw = no_compile_qr(Y, mode="reduced").Q.to(dtype=q.dtype)
+        projected = q32 @ promote(basis_raw)
+        sort_idx = torch.argsort(compiled_einsum("ij,ij->j", projected, projected), descending=True)
+        basis_raw = basis_raw.index_select(1, sort_idx)
+        signs = compiled_einsum("ij,ij->j", old_basis32, promote(basis_raw))
+        signs = torch.where(signs < 0, -torch.ones_like(signs), torch.ones_like(signs)).to(dtype=basis_raw.dtype)
+        basis = basis_raw * signs.view(1, -1)
+        out.append(basis)
+
+    return out
+
+
+@decorator_no_fullgraph
+def update_psgd_eigenbasis(Q: List[Tensor], Q_basis: List[Tensor], *states: Tensor):
+    new_basis = get_psgd_eigenbasis(Q, Q_basis)
+    _transform_projected_state(Q_basis, new_basis, *states)
+
+    for i, (old_basis, new_basis_i) in enumerate(zip(Q_basis, new_basis)):
+        if old_basis is None:  # happens only if ndim < 2
+            continue
+        copy_stochastic_(old_basis, new_basis_i)
+
+
+def _stable_symmetric_basis(
+    m: Tensor,
+    max_eps: float = 1e-3,
+    min_eps: float = 1e-30,
+    *,
+    out_device=None,
+    out_dtype=None,
+):
+    out_device = m.device if out_device is None else out_device
+    out_dtype = m.dtype if out_dtype is None else out_dtype
+    m = promote(m.data)
+
+    eps = min_eps
+    while True:
+        try:
+            eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
+            _eigval, eigvec = no_compile_eigh(m + eps * eye)
+            return torch.flip(eigvec, [1]).to(device=out_device, dtype=out_dtype)
+        except torch.OutOfMemoryError:
+            if m.device.type == "cpu":
+                raise
+            if torch.cuda.is_available():
+                torch.cuda.synchronize(m.device)
+            clean()
+            m = m.cpu()
+        except RuntimeError as e:
+            if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
+                torch.cuda.synchronize(m.device)
+                clean()
+                m = m.cpu()
+            elif m.dtype != torch.double:
+                m = m.double()
+            elif eps < max_eps:
+                eps = eps ** (2 / 3)
+            else:
+                raise
+        clean()
+
+
 def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
     """
     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
@@ -896,38 +1026,7 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
             final.append(None)
             continue
 
-        device, dtype = m.device, m.dtype
-        m = promote(m.data)
-
-        eps = min_eps
-        while True:
-            try:
-                eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
-                _eigval, eigvec = torch.linalg.eigh(m + eps * eye)
-                eigvec = eigvec.to(device=device, dtype=dtype)
-                break
-            except torch.OutOfMemoryError:
-                if m.device.type == "cpu":
-                    raise
-                if torch.cuda.is_available():
-                    torch.cuda.synchronize(m.device)
-                clean()
-                m = m.cpu()
-            except RuntimeError as e:
-                if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
-                    torch.cuda.synchronize(m.device)
-                    clean()
-                    m = m.cpu()
-                elif m.dtype != torch.double:
-                    m = m.double()
-                elif eps < max_eps:
-                    eps = eps ** (2 / 3)
-                else:
-                    raise
-            clean()
-
-        eigvec = torch.flip(eigvec, [1])
-        final.append(eigvec)
+        final.append(_stable_symmetric_basis(m, max_eps=max_eps, min_eps=min_eps))
 
     return final
@@ -988,6 +1087,8 @@ def scalar_guard(*args):
             out.append(torch.empty((), dtype=promote(ref.dtype), device=ref.device).fill_(x))
         elif isinstance(x, int):
             out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+        elif isinstance(x, Tensor) and x.is_floating_point() and x.ndim == 0:
+            out.append(x.to(dtype=promote(ref.dtype)))
         else:
             out.append(x)
     if len(xs) == 1:
@@ -1268,6 +1369,28 @@ def hasattr_none(obj, name):
     return getattr(obj, name, None) is not None
 
 
+def set_temporary(group: dict, tensor: Tensor, **kwargs):
+    if not kwargs:
+        return
+    state = group.setdefault("_tmp", {}).setdefault(id(tensor), {"tensor": tensor})
+    state.update(kwargs)
+
+
+def get_temporary(group: dict, tensor: Tensor):
+    tmp = group.get("_tmp")
+    return None if tmp is None else tmp.get(id(tensor))
+
+
+def take_temporary(group: dict, tensor: Tensor, *keys):
+    state = get_temporary(group, tensor)
+    if state is None:
+        return None if len(keys) == 1 else (None,) * len(keys)
+    out = tuple(state.pop(key, None) for key in keys)
+    if len(state) == 1:
+        group["_tmp"].pop(id(tensor), None)
+    return out[0] if len(keys) == 1 else out
+
+
 class ExactHVPFailed(ValueError):
     pass
@@ -1291,11 +1414,11 @@ class StatefulOptimizer(torch.optim.Optimizer):
     compile_step: bool = False
     hessian_approx: bool = False
     precond_schedule: Union[Callable, float, None] = None
-    stochastic_schedule: bool | Literal[use_default] = use_default
     finite_differences: bool = False
     fallback_to_finite_differences: bool = True
     _fallback_enabled: bool = False
     hvp_interval: int = 1  # grad is faster initially, hvp later
+    consume_grad: bool = True
 
     _INSTANCE_ATTRS = (
         "compile_step",
@@ -1303,30 +1426,22 @@
         "fallback_to_finite_differences",
         "hvp_interval",
         "hessian_approx",
+        "consume_grad",
     )
 
-    def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
+    def __init__(self, params, defaults, use_ema: bool = False):
         for attr in self._INSTANCE_ATTRS:
             if attr in defaults:
                 val = defaults.pop(attr)
                 if val is not use_default:
                     setattr(self, attr, val)
-        super().__init__(params, {**defaults, "foreach": foreach})
+        defaults.setdefault("multi_tensor", True)
+        super().__init__(params, defaults)
         self.use_ema = use_ema
         self.mapping = {}
         self.mapping_inverse = {}
 
-        if self.stochastic_schedule is use_default:
-            stochastic_schedule = None
-            for group in self.param_groups:
-                new = group.get("stochastic_schedule", stochastic_schedule)
-                if stochastic_schedule is not None and new != stochastic_schedule:
-                    raise ValueError("All parameter groups must have the same stochastic_schedule.")
-                stochastic_schedule = new
-            self.stochastic_schedule = stochastic_schedule
-
-        self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
-        self.precond_rng = random.Random(0x12312)
+        self.inner_group = {}
        self._is_preconditioning = None
 
         if self.hessian_approx and self.compile_step:
@@ -1338,22 +1453,24 @@ def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False
     def _store_stats(self, state_dict: dict[str, any]):
         state_dict["heavyball"] = {
             "inner_group": self.inner_group,
-            "precond_rng": pickle.dumps(self.precond_rng),
             "use_ema": self.use_ema,
             "ema_decay": self.ema_decay,
             "compile_step": self.compile_step,
             "hessian_approx": self.hessian_approx,
             "precond_schedule": pickle.dumps(self.precond_schedule),
-            "stochastic_schedule": self.stochastic_schedule,
             "fallback_to_finite_differences": self.fallback_to_finite_differences,
             "_fallback_enabled": self._fallback_enabled,
             "hvp_interval": self.hvp_interval,
         }
 
+    _REMOVED_STATS = frozenset({"stochastic_schedule", "precond_rng"})
+
     def _load_stats(self, state_dict):
         sd = state_dict.pop("heavyball", {})
         for k, v in sd.items():
-            if k in ("precond_rng", "precond_schedule"):
+            if k in self._REMOVED_STATS:
+                continue
+            if k in ("precond_schedule",):
                 v = pickle.loads(v)
             setattr(self, k, v)
@@ -1392,21 +1509,31 @@ def split_p_and_g_in_group(
         should_promote: bool = True,
         raw: bool = False,
     ):
+        tmp = group.get("_tmp")
         for p in group["params"]:
-            grad = getattr(p, "grad", None)
-            if grad is None and skip_none:
-                continue
-
-            p.grad = None
-
             if raw:
+                grad = getattr(p, "grad", None)
+                if grad is None and skip_none:
+                    continue
+                if self.consume_grad:
+                    p.grad = None
                 yield p, grad
                 continue
 
+            state = None if tmp is None else tmp.get(id(p))
+            grad = None if state is None else state.pop("grad", None)
+            if grad is None:
+                grad = getattr(p, "grad", None)
+                if grad is None and skip_none:
+                    continue
+
+            if self.consume_grad:
+                p.grad = None
+
             if group.get("merge_dims", False) and not p.data.is_contiguous():
                 for fmt in (torch.channels_last, torch.channels_last_3d):
                     if p.data.is_contiguous(memory_format=fmt):
-                        p._restore_memory_format = fmt
+                        set_temporary(group, p, restore_memory_format=fmt)
                         break
                 p.data = p.data.contiguous()
@@ -1414,20 +1541,23 @@ def split_p_and_g_in_group(
             for i, pv in enumerate(p_views):
                 self.mapping_inverse[_tensor_key(pv)] = (p, i)
 
-            vector = getattr(p, "vector", None)
-            hessian_vector = getattr(p, "hessian_vector", None)
-            p.vector = None
-            p.hessian_vector = None
-
-            grad, vs, hvs = [
-                [None] * len(p_views) if x is None else merge_group(group, x)  #
-                for x in (grad, vector, hessian_vector)
-            ]
+            if state is None:
+                vector = hessian_vector = None
+            else:
+                vector = state.pop("vector", None)
+                hessian_vector = state.pop("hessian_vector", None)
+                if len(state) == 1:
+                    tmp.pop(id(p), None)
+            grad = itertools.repeat(None, len(p_views)) if grad is None else merge_group(group, grad)
+            vs = itertools.repeat(None, len(p_views)) if vector is None else merge_group(group, vector)
+            hvs = itertools.repeat(None, len(p_views)) if hessian_vector is None else merge_group(group, hessian_vector)
             for pv, g, v, hv in zip(p_views, grad, vs, hvs):
                 g = promote_detach(g, should_promote)
-                pv.vector = promote_detach(v, should_promote)
-                pv.hessian_vector = promote_detach(hv, should_promote)
+                v = promote_detach(v, should_promote)
+                hv = promote_detach(hv, should_promote)
+                if v is not None or hv is not None:
+                    set_temporary(group, pv, vector=v, hessian_vector=hv)
                 yield pv, g
 
     def state_size(self) -> int:
@@ -1501,10 +1631,10 @@ def _finite_differences_hvp(self, closure):
         for group in self.param_groups:
             for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
                 grads.append(g)
-                p.vector = torch.randn_like(p)
-                p.orig = p.data.clone()
+                vector = torch.randn_like(p)
+                set_temporary(group, p, vector=vector, orig=p.data.clone())
                 # scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
-                stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
+                stochastic_add_(p.data, vector, torch.finfo(torch.float32).eps ** 0.5)
 
         with torch.enable_grad():
             closure()
@@ -1513,22 +1643,22 @@ def _finite_differences_hvp(self, closure):
         # this costs more memory, but the imprecision seems too severe to use the other method
         for group in self.param_groups:
             for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
-                p.grad = grads.pop(0)
-                stochastic_add_divide_(g, p.grad, -1, torch.finfo(p.dtype).eps ** 0.5)
-                p.hessian_vector = g
-                p.data.copy_(p.orig)
-                del p.orig
+                grad = grads.pop(0)
+                set_temporary(group, p, grad=grad, hessian_vector=g)
+                stochastic_add_divide_(g, grad, -1, torch.finfo(torch.float32).eps ** 0.5)
+                p.data.copy_(take_temporary(group, p, "orig"))
 
         return loss
 
     def _double_backward_hvp(self, closure):
         with torch.enable_grad(), patch_backward():
             loss = closure()
 
-        params, grads = [], []
+        params, grads, groups = [], [], []
         for group in self.param_groups:
             for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
                 params.append(p)
                 grads.append(g)
+                groups.append(group)
 
         if not params:
             raise ValueError("No parameter has gradients")
@@ -1541,10 +1671,8 @@ def _double_backward_hvp(self, closure):
             raise ExactHVPFailed(str(e.args))
 
         unused = []
-        for p, g, v, hv in zip(params, grads, vs, hvs):
-            p.hessian_vector = detach(hv)
-            p.grad = detach(g)
-            p.vector = detach(v)
+        for group, p, g, v, hv in zip(groups, params, grads, vs, hvs):
+            set_temporary(group, p, grad=detach(g), vector=detach(v), hessian_vector=detach(hv))
             if hv is None:
                 unused.append(list(p.shape))
@@ -1597,33 +1725,39 @@ def _handle_closure(self, closure):
                 self._fallback_enabled = True
                 return self._handle_closure(closure)
 
+    def _cleanup_temporary_tensors(self):
+        for group in self.param_groups:
+            tmp = group.pop("_tmp", None)
+            if tmp is None:
+                continue
+            for state in tmp.values():
+                fmt = state.get("restore_memory_format")
+                if fmt is None:
+                    continue
+                tensor = state["tensor"]
+                self.mapping_inverse.pop(_tensor_key(tensor), None)
+                tensor.data = tensor.data.to(memory_format=fmt)
+                self.mapping_inverse[_tensor_key(tensor)] = (tensor, 0)
+
     def step(self, closure: Optional[Callable] = None):
         if self.precond_schedule is None:
             self._is_preconditioning = False
         else:
-
self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng) + self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule) loss = self._handle_closure(closure) if self.use_ema: self.ema_update() # we assume that parameters are constant and that there are no excessive recompiles with torch.no_grad(), torch._dynamo.utils.disable_cache_limit(): - for group in self.param_groups: - if "param_count" not in group: - group["param_count"] = sum(p.numel() for p in group["params"]) - group["is_preconditioning"] = self._is_preconditioning - self._step(group) - for real, views in self.mapping.items(): - fmt = getattr(real, "_restore_memory_format", None) - if fmt is not None: - del self.mapping_inverse[_tensor_key(real)] - real.data = real.data.to(memory_format=fmt) - self.mapping_inverse[_tensor_key(real)] = (real, 0) - del real._restore_memory_format - for tensor in (real, *views): - for key in ("grad", "vector", "hessian_vector", "orig"): - if hasattr(tensor, key): - setattr(tensor, key, None) + try: + for group in self.param_groups: + if "param_count" not in group: + group["param_count"] = sum(p.numel() for p in group["params"]) + group["is_preconditioning"] = self._is_preconditioning + self._step(group) + finally: + self._cleanup_temporary_tensors() return loss @@ -2114,17 +2248,17 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]): def stochastic_round_(ref: Tensor, source: Tensor | None = None): if source is None: source = ref - if ref.dtype != torch.bfloat16: - return source.to(ref.dtype) - if source.dtype == torch.bfloat16: - return source - if source.dtype in (torch.float16, torch.float32, torch.float64): - source = source.to(torch.float32).view(dtype=torch.int32) - noise = sum(torch.randint_like(source, low=0, high=(1 << 16)) for _ in range(dither_steps)) - noise = noise + source - (dither_steps - 1) * (1 << 15) # center | x - (N-1)*delta/2 - noise = noise.bitwise_and(-65536) # 
FFFF0000 mask, preserves sign+exp+7 mantissa bits - return noise.view(dtype=torch.float32).bfloat16() - return source.to(ref.dtype) + dtype = torch.bfloat16 + else: + dtype = ref.dtype + if dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64): + return source.to(dtype) + + source = source.to(torch.float32).view(dtype=torch.int32) + noise = sum(torch.randint_like(source, low=0, high=(1 << 16)) for _ in range(dither_steps)) + noise = noise + source - (dither_steps - 1) * (1 << 15) # center | x - (N-1)*delta/2 + noise = noise.bitwise_and(-65536) # FFFF0000 mask, preserves sign+exp+7 mantissa bits + return noise.view(dtype=torch.float32).bfloat16() @decorator_knowngood @@ -2151,7 +2285,7 @@ def _compilable_update_( caution: bool, g: List[Optional[Tensor]], ): - for i, (u_, g_, p_) in enumerate(zip(u, g, p)): # lr is data-dependent -> can't compile a foreach + for i, (u_, g_, p_) in enumerate(zip(u, g, p)): # lr is data-dependent -> can't compile a multi-tensor op u_ = promote(u_.view_as(p_)) p32_ = promote(p_) if caution: @@ -2175,9 +2309,10 @@ def update_param_( _compilable_update_(param, update, decay, lr, caution, grad) -def precond_schedule(step, precond_scheduler): - precond_prob = max(step, 1) ** precond_scheduler[0] - precond_prob = math.log10(precond_prob) +@decorator_knowngood +def precond_schedule(step: Tensor, precond_scheduler): + precond_prob = step.clamp(min=1) ** precond_scheduler[0] + precond_prob = torch.log10(precond_prob) precond_prob = precond_prob ** precond_scheduler[1] + 1 return 1 / precond_prob @@ -2363,7 +2498,7 @@ def init_Q_exprs( @decorator_knowngood def psgd_balance_Q(Q): - norms = [promote(q.norm(float("inf"))).log() for q in Q] + norms = [promote(q.abs().max()).log() for q in Q] geometric_mean = sum([n for n in norms]) / len(Q) for q, n in zip(Q, norms): q *= (geometric_mean - n).exp() @@ -2524,10 +2659,10 @@ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor): @decorator_knowngood -def 
dampen_grad(g: Tensor, damp: float = 2**-13): - # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50 +def dampen_grad(g: Tensor, damp: float = 1e-9): v = torch.randn_like(g) - return v, g + damp * g.abs().mean() * v + damping = damp + torch.finfo(torch.float32).eps * g.abs() + return v, g + damping * v @decorator_knowngood @@ -2620,7 +2755,7 @@ def multi_flatten(*xs: Tuple[List[Tensor], int]): @decorator_knowngood -def dampen_multiple(g: List[Tensor], damp: float = 2**-13): +def dampen_multiple(g: List[Tensor], damp: float = 1e-9): vs = [] gs = [] for g_ in g: @@ -2631,8 +2766,7 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13): def casted_einsum(expr: str, *args: Tensor) -> Tensor: - md = min_dtype(args) - return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype) + return compiled_einsum(expr, *[promote(a) for a in args]).to(args[-1].dtype) @decorator_knowngood @@ -2671,14 +2805,13 @@ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "v conjB = torch.randn_like(G) exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same A = casted_einsum(exprA, *Q, G) - solve = torch.compiler.disable(torch.linalg.solve_triangular) transposed_shape = original_shape = conjB.shape prev_i = -1 qs, conjB = _psgd_calc_scalars_(Q, conjB) for i, tri_q in qs: conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i) prev_i = i - conjB = solve(tri_q, conjB, upper=True, left=False) + conjB = no_compile_solve_triangular(tri_q, conjB, upper=True, left=False) conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1) return A, conjB @@ -2687,7 +2820,7 @@ def max_singular_value_exact(A, use_lobpcg: bool = False): try: if use_lobpcg: A = A @ A.T - eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True) + eigval, _ = no_compile_lobpcg(A, k=1, 
largest=True) return eigval[0].sqrt() else: return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype) # == linalg.matrix_norm(A, ord=2) @@ -2727,7 +2860,7 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None): Adapted from @evanatyourservice """ if max_abs is None: - max_abs = A.norm(float("inf")).clamp(min=1e-8) + max_abs = A.abs().max().clamp(min=1e-8) # cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result k = 2 ** math.ceil(math.log2(math.log2(min(A.shape)))) # next-largest-power-of-2 of log2-of-size @@ -2768,6 +2901,45 @@ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False, return max_singular_value_power_iter(A, None, iterations=power_iter) +@decorator_knowngood +def max_eigenvalue_spd(A_outer: Tensor, power_iter: int = 4) -> Tensor: + """Power iteration for the largest eigenvalue of a symmetric positive (semi)definite matrix. + Exploits A = A^T: A^T A = A^2, so v -> A^T(Av) = v -> A(Av), saving a transpose. 
+ Uses x @ A.mT (gemm transB=true) for faster BLAS dispatch than A.mv(x).""" + if A_outer.ndim < 2: + return A_outer.max() + x_norm, max_idx = A_outer.norm(dim=1).max(dim=0) + x_norm = promote(x_norm) + + def _inner(): + x = A_outer.index_select(0, max_idx).flatten().contiguous() + A = promote(A_outer) / x_norm + x = x / x_norm + + def _mv(x): + return promote((x @ A.mT) @ A.mT) + + for _ in range(power_iter): + x = F.normalize(_mv(x), dim=0) + return (x @ _mv(x)).sqrt() * x_norm + + return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone()).squeeze() + + +@decorator_knowngood +def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None: + Q_ = promote(Q) + R = (Q_.T - Q_).contiguous() + R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal + R = R / R_norm + RQ = R @ Q_ + RRQ = R @ RQ + tr_RQ = RQ.diagonal().sum() + tr_RRQ = RRQ.diagonal().sum() + a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size) + copy_stochastic_(Q, Q_ + a * (RQ + 0.5 * a * RRQ)) + + @decorator_knowngood def clamped_max_singular_value( A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16 @@ -2821,10 +2993,10 @@ def _approx(): @decorator_knowngood -def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False): +def _balance_to_triu(Q: "TriuOrLine"): if isinstance(Q[0], tuple): psgd_balance_Q([o[1] for o in Q]) - return line_to_triu(Q, symmetric_output) + return line_to_triu(Q) psgd_balance_Q(Q) return Q @@ -2927,33 +3099,20 @@ def _chebychef_coeff(degree: int, device, eps: float = 1e-8): return coeff0.float(), coeffs.float() -@decorator_knowngood -def _psgd_default_preconditioner_grad( - terms: List[Tuple[Tensor, Tensor]], - Q: List[Tensor], -) -> List[Tensor]: - out = [] - for q, (x, y) in zip(Q, terms): - x = promote(x) - y = promote(y) - update = x - y - if q.ndim < 2: - update = promote(q) * update - else: - update = (promote(q) @ update).triu() - 
out.append(update) - return out +def _update_lb(ell: Tensor, lb_state: Tensor, beta: Tensor) -> Tensor: + ell = promote(ell) + ell = ell.maximum(promote(lb_state) + (ell - promote(lb_state)) * (1 - beta)) + copy_stochastic_(lb_state, ell) + return ell -@decorator +@decorator_no_fullgraph def psgd_update_precond( G: Tensor, precond_lr: float, oq: "TriuOrLine", store_triu_as_line: bool, - velocity: Optional[List[Tensor]], beta2: float, - ortho_method: Optional[str], V: Tensor, running_lower_bound: List[Tensor], lower_bount_beta: float, @@ -2965,15 +3124,61 @@ def psgd_update_precond( precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G) A, conjB = psgd_calc_A_and_conjB(G, Q, V) - terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs] - del A, conjB, V - updates = _psgd_default_preconditioner_grad(terms, Q) - _psgd_precond_update_( - updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter - ) + del V + + for oq_i, q, exprG, lb_state in zip(oq, Q, exprGs, running_lower_bound): + term1 = promote(compiled_einsum(exprG, A, A)) + term2 = promote(compiled_einsum(exprG, conjB, conjB)) + + if q.ndim < 2: + ell = _update_lb((term1 + term2).max(), lb_state, lower_bount_beta) + update = promote(q) * (term1 - term2) + else: + ell = _update_lb(max_eigenvalue_spd(term1 + term2, power_iter=power_iter), lb_state, lower_bount_beta) + update = (term1 - term2).triu() @ promote(q) + if store_triu_as_line: + update = triu_to_line([update])[0][1] + + real_oq = oq_i[1] if isinstance(oq_i, tuple) else oq_i + copy_stochastic_(real_oq, promote(real_oq) - update / ell * precond_lr) return None +@decorator_knowngood +def psgd_pro_update_precond( + G: Tensor, + precond_lr: float, + Q: List[Tensor], + running_lower_bound: List[Tensor], + lower_bount_beta: float, + power_iter: int, + dampening: float, +) -> None: + """Update Kronecker product preconditioner Q with Q0.5EQ1.5 (PRO) 
method.""" + psgd_balance_Q(Q) + exprGs = calcG_expr(ndim_tuple(Q), G.ndim) + precond_lr, lower_bount_beta = scalar_guard(precond_lr, lower_bount_beta, G) + + damping = dampening + torch.finfo(torch.float32).eps * G.abs() + Pg = psgd_precond_grad(G + damping * torch.randn_like(G), Q) + + total_numel = G.numel() + for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound): + term1 = compiled_einsum(exprG, Pg, Pg) + q_ = promote(q) + + if q.ndim < 2: + term2 = total_numel / max(1, q.numel()) + ell = _update_lb(term1.max() + term2, lb_state, lower_bount_beta) + copy_stochastic_(q, q_ - q_ * (term1 - term2) / ell * precond_lr) + else: + term2 = total_numel / q.shape[0] + ell = _update_lb(max_eigenvalue_spd(term1, power_iter=power_iter) + term2, lb_state, lower_bount_beta) + copy_stochastic_(q, q_ - (term1 @ q_ - term2 * q_) / ell * precond_lr) + procrustes_step(q) + del Pg + + @decorator_knowngood def bf16_matmul(x: Tensor, y: Tensor): return (promote(x) @ promote(y)).to(x.dtype) @@ -3037,8 +3242,8 @@ def eigvecs_product_rank1( using the Householder reflector with first column v. Never materializes V. Args: - G: shape (..., d) — gradient row(s) you want to rotate into eigenbasis. - v: shape (d,) — current unit direction (top eigenvector of P). + G: shape (..., d) - gradient row(s) you want to rotate into eigenbasis. + v: shape (d,) - current unit direction (top eigenvector of P). w: optional Householder vector w; pass to reuse across calls. 
Returns: @@ -3087,7 +3292,7 @@ def _psgd_precond_update_( q = promote(oq) if update.ndim < 2: - lb = update.norm(float("inf")) + lb = update.abs().max() else: lb = max_singular_value(update, power_iter=power_iter) update = promote(update) @@ -3101,65 +3306,7 @@ def _psgd_precond_update_( @decorator_knowngood -def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int): - """ - I: Identity - U: Update / gg / target - Q: q, preconditioner - scale: scalar scale - --- - U = T * scale - I - F = I - U # = 2I - U * scale - O = F @ Q @ F - Q - """ - out = [] - for gg, q in zip(GG, Q): - if gg.ndim < 2: - scale = max(1, gg.numel()) / numel - target = promote(gg) - update = target * scale - 1 - out.append(q - (1 - update) * q * (1 - update)) - else: - scale = gg.size(0) / numel - gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale - update = q - casted_einsum("ab,cd,bc", gg, gg, q) - out.append(update + update.T) # make matrix symmetric - return out - - @decorator -def inverse_free_psgd_update_precond( - G: Tensor, - precond_lr: float, - oq: List[Tensor], - store_triu_as_line: bool, - velocity: Optional[List[Tensor]], - beta2: float, - ortho_method: Optional[str], - V: None, - running_lower_bound: List[Tensor], - lower_bount_beta: float, - power_iter: int, -) -> Tensor: - """Update Kronecker product preconditioner Q with pair (V, G).""" - assert V is None - assert ortho_method is None - assert velocity is None - del V, ortho_method, velocity - - Q = _balance_to_triu(oq, True) - precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G) - exprGs = calcG_expr(ndim_tuple(Q), G.ndim) - - G = psgd_precond_grad(G, Q) - terms = [compiled_einsum(exprG, G, G) for exprG in exprGs] - matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel()) - _psgd_precond_update_( - matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter - ) - return G - - @decorator_knowngood 
def _clip(x, norm, clip_at, eps=1e-8): x32 = promote(x) @@ -3420,15 +3567,13 @@ def triu_to_line(Q_list: List[Tensor]): @decorator_knowngood -def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False): +def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]): new = [] for shape, q in Q_list: if shape is not None: x, y = torch.triu_indices(*shape, device=q.device) q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype) q_mat[x, y] = q - if symmetric_output: - q_mat[y, x] = q q = q_mat new.append(q) return new @@ -3443,14 +3588,10 @@ def warn_once(msg): _warned.add(msg) -def psgd_should_update( - group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob" -): +def psgd_should_update(group, prob: Union[float, callable], name: str = "cumulative_prob"): group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1 if not isinstance(prob, float): prob = prob(group[f"{name}_prob_step"]) - if group["stochastic_schedule"]: - return rng.random() < prob cumulative_prob = group.get(name, 0) group[name] = cumulative_prob + prob return int(group[name]) > int(cumulative_prob) @@ -3471,18 +3612,13 @@ def precond_grad_cached_( cached_q: List[Tensor], caution: bool = False, grad: Optional[Tensor] = None, - cast: bool = True, ): if caution: ea = _compilable_cautioning(grad, ea) - md = min_dtype(list(cached_q) + [ea]) - args = [q.to(md) for q in cached_q] - args = args + [ea.to(md)] + args = [promote(q) for q in cached_q] + args = args + [promote(ea)] expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim) - new = compiled_einsum(expr, *args) - if cast: - return new.to(ea.dtype) - return new + return compiled_einsum(expr, *args) TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]] @@ -3490,7 +3626,7 @@ def precond_grad_cached_( @decorator_knowngood def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: 
List[Tensor]): - precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False) + precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad) update_param_(param, precond, lr, decay, caution=False) @@ -3517,17 +3653,14 @@ def psgd_precond_grad( caution: bool = False, grad: Optional[Tensor] = None, store_triu_as_line: bool = False, - symmetric_output: bool = False, ): if caution: ea = _compilable_cautioning(grad, ea) if store_triu_as_line: - preconds = line_to_triu(preconds, symmetric_output) - md = min_dtype(list(preconds) + [ea]) - args = [q.to(md) for q in preconds] + preconds = line_to_triu(preconds) + args = [promote(q) for q in preconds] expr = precond_grad_expr(ndim_tuple(args), ea.ndim) - new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md)) - return new.to(ea.dtype) + return compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea)) @decorator_knowngood @@ -3540,7 +3673,6 @@ def _compilable_fused_psgd_precond_grad( caution, preconds: TriuOrLine, store_triu_as_line: bool = False, - symmetric_output: bool = False, ): precond = psgd_precond_grad( ea, @@ -3548,7 +3680,6 @@ def _compilable_fused_psgd_precond_grad( caution=caution, grad=grad, store_triu_as_line=store_triu_as_line, - symmetric_output=symmetric_output, ) update_param_(param, precond, lr, decay, caution=False, grad=grad) @@ -3562,12 +3693,9 @@ def fused_psgd_precond_grad( caution, preconds: TriuOrLine, store_triu_as_line: bool = False, - symmetric_output: bool = False, ): lr, decay = scalar_guard(lr, decay, param[0]) - _compilable_fused_psgd_precond_grad( - ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output - ) + _compilable_fused_psgd_precond_grad(ea, param, lr, grad, decay, caution, preconds, store_triu_as_line) @decorator_knowngood @@ -3619,6 +3747,41 @@ def caution(g, update): return _compilable_cautioning(g, update) +@decorator_knowngood +def _compilable_hyperball_( + p: List[Tensor], + u: 
List[Tensor], + init_norm: List[Tensor], + lr: Tensor, + decay: float, + caution: bool, + g: List[Tensor], +): + for op, u_, n_, g_ in zip(p, u, init_norm, g): + u_ = promote(u_.view_as(op)) + p_ = promote(op) + if decay != 0: + u_ = u_ + p_ * decay + if caution: + u_ = _compilable_cautioning(promote(g_), u_) + u_norm = u_.norm() + u_norm = u_norm.clamp(min=1e-12) + u_ = u_ / u_norm + p_ = p_ - lr * u_ * n_ + p_norm = p_.norm() + p_norm = p_norm.clamp(min=1e-12) + p_ = p_ * (n_ / p_norm) + copy_stochastic_(op, p_) + + +def hyperball_step_(param, update, init_norm, lr, decay, caution, grad): + param, update, init_norm, grad = list_guard(param, update, init_norm, grad) + lr = scalar_guard(lr, param[0]) + if not caution: + grad = [None] * len(param) + _compilable_hyperball_(param, update, init_norm, lr, decay, caution, grad) + + def _inner_precond_update_prob_schedule( n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000 ): diff --git a/interactive/playground.py b/interactive/playground.py deleted file mode 100644 index bc80f62..0000000 --- a/interactive/playground.py +++ /dev/null @@ -1,1596 +0,0 @@ -import functools -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple - -import gradio as gr -import numpy as np -import plotly.graph_objects as go -import torch -import torch.nn as nn -from plotly.subplots import make_subplots -from sklearn.decomposition import PCA - -import heavyball -import heavyball.chainable as C - -# TensorFlow Playground inspired colors -COLORS = { - "bg": "#ffffff", - "surface": "#f7f7f7", - "primary": "#ff6f00", # Orange - "secondary": "#0d47a1", # Blue - "positive": "#4caf50", # Green - "negative": "#f44336", # Red - "text": "#212121", - "text_light": "#757575", - "border": "#e0e0e0", # Component categories - "gradient": "#9c27b0", # Purple - Gradient input - "momentum": "#2196f3", # Blue - Momentum transforms - "scaling": "#ff5722", # Deep Orange - Scaling 
transforms - "regularization": "#009688", # Teal - Regularization - "normalization": "#795548", # Brown - Normalization - "update": "#4caf50", # Green - Update rules -} - - -# Problem base class -class Problem(ABC): - """Base class for optimization problems""" - - @property - @abstractmethod - def dim(self) -> int: - """Dimension of the parameter space""" - pass - - @abstractmethod - def init(self) -> np.ndarray: - """Initial parameters""" - pass - - @abstractmethod - def loss(self, x: torch.Tensor) -> torch.Tensor: - """Compute loss given parameters""" - pass - - def bounds(self) -> Optional[List[Tuple[float, float]]]: - """Parameter bounds for visualization (None if unbounded)""" - return None - - def optimum(self) -> Optional[np.ndarray]: - """Known optimal point if available""" - return None - - -class Function2D(Problem): - """Wrapper for 2D test functions""" - - def __init__(self, func, bounds, init, optimal=None): - self.func = func - self._bounds = bounds - self._init = init - self._optimal = optimal - - @property - def dim(self) -> int: - return 2 - - def init(self) -> np.ndarray: - return np.array(self._init) - - def loss(self, x: torch.Tensor) -> torch.Tensor: - if x.dim() == 1: - return self.func(x) - else: - # Handle batch of parameters - return self.func(x[0]) - - def bounds(self): - return self._bounds - - def optimum(self): - return np.array(self._optimal) if self._optimal else None - - -class MLPProblem(Problem): - """Train a small MLP to predict |x|""" - - def _data(self, n_samples: int = 128, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - def __init__(self, hidden_size=4, n_samples=128): - self.model = nn.Sequential( - nn.Linear(1, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, 1), - ) - self.x_data, self.y_data = self._data(n_samples) - self._dim = sum(p.numel() for p in self.model.parameters()) - - @property - def dim(self) -> int: - return self._dim 
- - def init(self) -> np.ndarray: - # Initialize model and return flattened parameters - torch.manual_seed(42) - for p in self.model.parameters(): - if len(p.shape) >= 2: - nn.init.xavier_uniform_(p) - else: - nn.init.zeros_(p) - return self._flatten_params() - - def loss(self, x: torch.Tensor) -> torch.Tensor: - # Instead of modifying model parameters, we'll use functional approach - # Split x into weight and bias tensors - idx = 0 - params = [] - for p in self.model.parameters(): - numel = p.numel() - param = x[idx : idx + numel].view(p.shape) - params.append(param) - idx += numel - - # Manually compute forward pass using the parameters from x - # For a 2-layer MLP: Linear -> ReLU -> Linear - w1, b1, w2, b2, w3, b3 = params - - h = torch.nn.functional.linear(self.x_data, w1, b1) - h = torch.nn.functional.relu(h) - pred = torch.nn.functional.linear(h, w2, b2) - h = torch.nn.functional.relu(h) - pred = torch.nn.functional.linear(h, w3, b3) - - return nn.functional.mse_loss(pred, self.y_data) - - def _flatten_params(self) -> np.ndarray: - """Flatten model parameters to vector""" - return torch.cat([p.data.view(-1) for p in self.model.parameters()]).numpy() - - def _unflatten_params(self, x: torch.Tensor): - """Set model parameters from flat vector (used for evaluation, not training)""" - idx = 0 - for p in self.model.parameters(): - numel = p.numel() - p.data.copy_(x[idx : idx + numel].view(p.shape)) - idx += numel - - -class AbsMLP(MLPProblem): - def _data(self, n_samples: int = 128, seed: int = 42): - x = torch.rand((n_samples, 1)) * 4 - 2 - return x, x.abs() - - -class SineMLP(MLPProblem): - def _data(self, n_samples: int = 128, seed: int = 42): - x = torch.rand((n_samples, 1)) * 4 - 2 - return x, x.sin() - - -class ExpMLP(MLPProblem): - def _data(self, n_samples: int = 128, seed: int = 42): - x = torch.rand((n_samples, 1)) * 4 - 2 - return x, x.exp() - - -class QuadraticBowl(Problem): - """N-dimensional quadratic bowl""" - - def __init__(self, dim=10): - 
self._dim = dim - self.center = np.random.randn(dim) * 2 - - @property - def dim(self) -> int: - return self._dim - - def init(self) -> np.ndarray: - return np.random.randn(self._dim) * 3 - - def loss(self, x: torch.Tensor) -> torch.Tensor: - center = torch.tensor(self.center, dtype=x.dtype, device=x.device) - return torch.sum((x - center) ** 2) - - def optimum(self): - return self.center - - -class StyblinskiTang(Problem): - """Styblinski-Tang function (separable, multimodal)""" - - def __init__(self, dim=4): - self._dim = dim - - @property - def dim(self) -> int: - return self._dim - - def init(self) -> np.ndarray: - return np.random.uniform(-5, 5, self._dim) - - def loss(self, x: torch.Tensor) -> torch.Tensor: - return 0.5 * torch.sum(x**4 - 16 * x**2 + 5 * x) - - def bounds(self): - return [(-5, 5)] * self._dim - - def optimum(self): - # Global minimum at x_i = -2.903534 for all i - return np.full(self._dim, -2.903534) - - -# Test functions -PROBLEMS = { # 2D Problems - "Simple Bowl (2D)": Function2D( - func=lambda x: (x[0] - 1) ** 2 + (x[1] - 2) ** 2, - bounds=[(-3, 5), (-2, 6)], - init=[-1.0, 0.0], - optimal=[1.0, 2.0], - ), - "Ravine (2D)": Function2D( - func=lambda x: (x[0] - 1) ** 2 + (x[1] - 2) ** 100, - bounds=[(-3, 5), (-2, 6)], - init=[-1.0, 1.0], - optimal=[1.0, 2.0], - ), - "Rosenbrock (2D)": Function2D( - func=lambda x: (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2, - bounds=[(-2, 2), (-1, 3)], - init=[-1.0, 1.0], - optimal=[1.0, 1.0], - ), - "Himmelblau (2D)": Function2D( - func=lambda x: (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2, - bounds=[(-5, 5), (-5, 5)], - init=[0.0, 0.0], - optimal=[3.0, 2.0], - ), - "Beale (2D)": Function2D( - func=lambda x: ( - (1.5 - x[0] + x[0] * x[1]) ** 2 - + (2.25 - x[0] + x[0] * x[1] ** 2) ** 2 - + (2.625 - x[0] + x[0] * x[1] ** 3) ** 2 - ), - bounds=[(-4.5, 4.5), (-4.5, 4.5)], - init=[1.0, 1.0], - optimal=[3.0, 0.5], - ), # High-dimensional Problems - "MLP abs(x) (13D)": AbsMLP(hidden_size=4, 
n_samples=128), - "MLP sin(x) (13D)": SineMLP(hidden_size=4, n_samples=128), - "MLP exp(x) (13D)": ExpMLP(hidden_size=4, n_samples=128), - "Quadratic Bowl (10D)": QuadraticBowl(dim=10), - "Styblinski-Tang (4D)": StyblinskiTang(dim=4), - "Styblinski-Tang (8D)": StyblinskiTang(dim=8), -} - -# Chainable optimizer components -OPTIMIZER_BLOCKS = { - "gradient": { - "name": "🎯 Gradient Input", - "color": COLORS["gradient"], - "components": [ - { - "id": "gradient_input", - "name": "Raw Gradient", - "icon": "∇", - "description": "Start of the pipeline - raw gradients from backprop", - "func": None, # This is just a placeholder - "params": {}, - } - ], - }, - "momentum": { - "name": "⚡ Momentum", - "color": COLORS["momentum"], - "components": [ - { - "id": "heavyball", - "name": "Heavy Ball", - "icon": "🏐", - "description": "Classic momentum: v = βv - g", - "func": C.heavyball_momentum, - "params": {"beta": 0.9}, - }, - { - "id": "nesterov", - "name": "Nesterov", - "icon": "🚀", - "description": "Look-ahead momentum", - "func": C.nesterov_momentum, - "params": {"beta": 0.9}, - }, - { - "id": "nesterov_ema", - "name": "Nesterov EMA", - "icon": "📈", - "description": "Exponential moving average variant", - "func": C.nesterov_ema, - "params": {"beta": 0.9}, - }, - { - "id": "basic_ema", - "name": "Basic EMA", - "icon": "📉", - "description": "Simple exponential moving average", - "func": C.exp_avg, - "params": {"betas": (0.9, 0.999)}, - }, - ], - }, - "scaling": { - "name": "📊 Adaptive Scaling", - "color": COLORS["scaling"], - "components": [ - { - "id": "adam_scale", - "name": "Adam Scaling", - "icon": "👑", - "description": "Adaptive per-parameter learning rates", - "func": C.scale_by_adam, - "params": {"betas": (0.9, 0.999), "eps": 1e-8}, - }, - { - "id": "rmsprop_scale", - "name": "RMSprop Scaling", - "icon": "📉", - "description": "Root mean square normalization", - "func": C.scale_by_exp_avg_sq, - "params": {"beta2": 0.999, "eps": 1e-8}, - }, - { - "id": "adagrad_scale", - 
"name": "AdaGrad Scaling", - "icon": "📐", - "description": "Accumulate all past gradients", - "func": C.scale_by_exp_avg_sq, - "params": {"beta2": 1.0, "eps": 1e-8}, - }, - ], - }, - "regularization": { - "name": "⚖️ Regularization", - "color": COLORS["regularization"], - "components": [ - { - "id": "weight_decay", - "name": "Weight Decay", - "icon": "🪶", - "description": "L2 regularization (AdamW style)", - "func": C.weight_decay_to_ema, - "params": {"weight_decay_to_ema": 0.01, "ema_beta": 0.999}, - }, - { - "id": "weight_decay_init", - "name": "Decay to Init", - "icon": "🎯", - "description": "Pull weights toward initialization", - "func": C.weight_decay_to_init, - "params": {"weight_decay_to_init": 0.01}, - }, - { - "id": "l1_weight_decay", - "name": "L1 Weight Decay", - "icon": "⚡", - "description": "L1 regularization to EMA", - "func": C.l1_weight_decay_to_ema, - "params": {"weight_decay_to_ema": 0.01, "ema_beta": 0.999}, - }, - ], - }, - "normalization": { - "name": "🔧 Gradient Processing", - "color": COLORS["normalization"], - "components": [ - { - "id": "grad_clip", - "name": "Gradient Clipping", - "icon": "✂️", - "description": "Clip gradient norm", - "func": functools.partial(C.global_clip, clip_fn=heavyball.utils.l2_clip_), - "params": {"max_norm": 1.0}, - }, - { - "id": "sign", - "name": "Sign SGD", - "icon": "±", - "description": "Use only gradient signs", - "func": C.sign, - "params": {"graft": True}, - }, - { - "id": "orthogonalize", - "name": "Orthogonalize", - "icon": "⊥", - "description": "Orthogonalize gradient to parameter", - "func": C.orthogonalize_grad_to_param, - "params": {"eps": 1e-8}, - }, - { - "id": "orthogonalize_update", - "name": "Orthogonalize Update", - "icon": "⊗", - "description": "Orthogonalize the update itself", - "func": C.orthogonalize_update, - "params": {}, - }, - ], - }, - "advanced_scaling": { - "name": "🚀 Advanced Optimizers", - "color": COLORS["scaling"], - "components": [ - { - "id": "laprop", - "name": "Laprop", - 
"icon": "🌊", - "description": "Layerwise adaptive propagation", - "func": C.scale_by_laprop, - "params": {"betas": (0.9, 0.999), "eps": 1e-8}, - }, - { - "id": "adopt", - "name": "ADOPT", - "icon": "🎯", - "description": "Adaptive gradient methods", - "func": C.scale_by_adopt, - "params": {"betas": (0.9, 0.9999), "eps": 1e-12}, - }, - ], - }, - "preconditioning": { - "name": "🔮 Preconditioning", - "color": COLORS["gradient"], - "components": [ - { - "id": "soap", - "name": "SOAP", - "icon": "🧼", - "description": "Shampoo-based preconditioning", - "func": C.scale_by_soap, - "params": { - "shampoo_beta": 0.99, - "max_precond_dim": 10000, - "precondition_1d": False, - "is_preconditioning": True, - "betas": (0.9, 0.999), - "eps": 1e-8, - }, - }, - { - "id": "psgd", - "name": "PSGD", - "icon": "🎲", - "description": "Preconditioned SGD", - "func": functools.partial(C.scale_by_psgd, cached=False), - "params": { - "precond_lr": 0.1, - "max_size_triangular": 1024, - "precondition_frequency": 10, - "adaptive": False, - "store_triu_as_line": True, - "q_dtype": "float32", - "inverse_free": False, - "precond_init_scale": 1.0, - "precond_init_scale_scale": 0.0, - "precond_init_scale_power": 1.0, - "min_ndim_triangular": 2, - "memory_save_mode": None, - "dampening": 1.0, - "is_preconditioning": True, - "betas": (0.9, 0.999), - "eps": 1e-8, - "ortho_method": "qr", - "lower_bound_beta": 0.999, - "precond_update_power_iterations": 1, - }, - }, - { - "id": "psgd_lra", - "name": "PSGD LRA", - "icon": "📐", - "description": "Low-rank approximation PSGD", - "func": C.scale_by_psgd_lra, - "params": { - "precond_lr": 0.1, - "rank": 4, - "param_count": 10000, - "precondition_frequency": 10, - "precond_init_scale": 1.0, - "precond_init_scale_scale": 0.0, - "precond_init_scale_power": 1.0, - "q_dtype": "float32", - "is_preconditioning": True, - "eps": 1e-8, - "betas": (0.9, 0.999), - }, - }, - { - "id": "delayed_psgd", - "name": "Delayed PSGD", - "icon": "⏱️", - "description": "PSGD with 
delayed preconditioner updates", - "func": functools.partial(C.scale_by_delayed_psgd, cached=False), - "params": { - "precond_lr": 0.1, - "max_size_triangular": 1024, - "precondition_frequency": 10, - "adaptive": False, - "store_triu_as_line": True, - "q_dtype": "float32", - "inverse_free": False, - "precond_init_scale": 1.0, - "precond_init_scale_scale": 0.0, - "precond_init_scale_power": 1.0, - "min_ndim_triangular": 2, - "memory_save_mode": None, - "dampening": 1.0, - "is_preconditioning": True, - "betas": (0.9, 0.999), - "eps": 1e-8, - "ortho_method": "qr", - "lower_bound_beta": 0.999, - "precond_update_power_iterations": 1, - }, - }, - { - "id": "delayed_psgd_lra", - "name": "Delayed PSGD LRA", - "icon": "⏰", - "description": "Delayed low-rank PSGD", - "func": C.scale_by_delayed_psgd_lra, - "params": { - "precond_lr": 0.1, - "rank": 4, - "param_count": 10000, - "precondition_frequency": 10, - "precond_init_scale": 1.0, - "precond_init_scale_scale": 0.0, - "precond_init_scale_power": 1.0, - "q_dtype": "float32", - "is_preconditioning": True, - "eps": 1e-8, - "betas": (0.9, 0.999), - }, - }, - ], - }, - "adaptive_lr": { - "name": "🎛️ Adaptive Learning Rate", - "color": COLORS["regularization"], - "components": [ - { - "id": "d_adapt", - "name": "D-Adaptation", - "icon": "📈", - "description": "Automatic learning rate adaptation", - "func": C.scale_by_d_adaptation, - "params": {"initial_d": 1.0}, - }, - { - "id": "lr_adapt", - "name": "LR Adaptation", - "icon": "🔄", - "description": "Learning rate adaptation", - "func": C.scale_by_lr_adaptation, - "params": {"initial_d": 1.0, "lr_lr": 0.1}, - }, - { - "id": "pointwise_lr_adapt", - "name": "Pointwise LR", - "icon": "🎚️", - "description": "Per-parameter learning rate", - "func": C.scale_by_pointwise_lr_adaptation, - "params": {"initial_d": 1.0, "lr_lr": 0.1}, - }, - ], - }, - "special": { - "name": "✨ Special Methods", - "color": COLORS["update"], - "components": [ - { - "id": "schedule_free", - "name": 
"Schedule-Free",
-                "icon": "🗓️",
-                "description": "No learning rate schedule needed",
-                "func": C.update_by_schedule_free,
-                "params": {"r": 0.0, "weight_lr_power": 2.0},
-            },
-            {
-                "id": "msam",
-                "name": "MSAM",
-                "icon": "🏔️",
-                "description": "Momentum SAM optimizer",
-                "func": C.update_by_msam,
-                "params": {"sam_step_size": 0.05},
-            },
-            {
-                "id": "mup",
-                "name": "μP Approx",
-                "icon": "📏",
-                "description": "Maximal update parametrization",
-                "func": C.mup_approx,
-                "params": {},
-            },
-            {
-                "id": "palm_beta2",
-                "name": "PALM β₂",
-                "icon": "🌴",
-                "description": "Dynamic β₂ scheduling for PALM",
-                "func": C.palm_beta2,
-                "params": {"beta2_scale": 0.8},
-            },
-            {
-                "id": "identity",
-                "name": "Identity",
-                "icon": "🔄",
-                "description": "Pass-through (no operation)",
-                "func": C.identity,
-                "params": {},
-            },
-        ],
-    },
-}
-
-# Pre-built optimizer recipes
-RECIPES = {
-    "SGD": ["gradient_input"],
-    "SGD + Momentum": ["gradient_input", "heavyball"],
-    "Adam": ["gradient_input", "adam_scale"],
-    "AdamW": ["gradient_input", "adam_scale", "weight_decay"],
-    "RMSprop": ["gradient_input", "rmsprop_scale"],
-    "Nesterov SGD": ["gradient_input", "nesterov"],
-    "SOAP": ["gradient_input", "soap"],
-    "Laprop": ["gradient_input", "laprop"],
-    "ADOPT": ["gradient_input", "adopt"],
-    "Sign SGD": ["gradient_input", "sign"],
-    "AdamW + Clipping": ["gradient_input", "grad_clip", "adam_scale", "weight_decay"],
-    "D-Adapted Adam": ["gradient_input", "adam_scale", "d_adapt"],
-    "PSGD": ["gradient_input", "psgd"],
-    "Schedule-Free AdamW": ["gradient_input", "adam_scale", "weight_decay", "schedule_free"],
-    "EMA SGD": ["gradient_input", "basic_ema"],
-    "Orthogonal Adam": ["gradient_input", "orthogonalize_update", "adam_scale"],
-    "PALM": ["gradient_input", "palm_beta2", "adam_scale"],
-    "Delayed PSGD": ["gradient_input", "delayed_psgd"],
-    "μP Adam": ["gradient_input", "mup", "adam_scale"],
-}
-
-
-def get_component_info(comp_id):
-    """Get component info by ID"""
-    for category in
OPTIMIZER_BLOCKS.values():
-        for comp in category["components"]:
-            if comp["id"] == comp_id:
-                return comp, category["color"]
-    return None, "#757575"
-
-
-def create_pipeline_display(pipeline):
-    """Create HTML display for the current pipeline"""
-    if not pipeline or pipeline == ["gradient_input"]:
-        return """
- Drop components here to build your optimizer pipeline -
- """ - - blocks_html = "" - for i, comp_id in enumerate(pipeline): - comp_info, color = get_component_info(comp_id) - if comp_info: - show_arrow = i < len(pipeline) - 1 - blocks_html += f""" -
-
-
{comp_info["icon"]}
-
{comp_info["name"]}
- -
- {'
' if show_arrow else ""} -
- """ - - return f""" -
- {blocks_html} -
- """ - - -def build_optimizer_from_pipeline(pipeline: List[str], params, kwargs): - """Build optimizer from pipeline of component IDs""" - fns = [] - opt_params = { - "lr": 0.001, - "step": 1, # Required for many functions - "caution": False, - "weight_decay": 0.0, - **kwargs, - } - - for comp_id in pipeline: - if comp_id == "gradient_input": - continue # Skip the input block - - # Find component - comp_info, _ = get_component_info(comp_id) - if comp_info and comp_info["func"] is not None: - fns.append(comp_info["func"]) - # Update parameters, handling special cases - params_to_add = comp_info["params"].copy() - - # Handle special parameter mappings - if "beta" in params_to_add and "betas" not in params_to_add: - # Convert single beta to betas tuple for functions expecting it - beta = params_to_add.pop("beta") - if "betas" in opt_params: - opt_params["betas"] = (beta, opt_params["betas"][1]) - else: - opt_params["betas"] = (beta, 0.999) - - # First add component defaults, then override with user-provided kwargs - for key, value in params_to_add.items(): - if key not in kwargs: # Only use default if not provided by user - opt_params[key] = value - - if not fns: - # Default to simple gradient descent - return C.BaseOpt(params, opt_params, fns=[C.identity]) - - return C.BaseOpt(params, opt_params, fns=fns) - - -def run_optimization(problem_name: str, pipeline: List[str], steps: int, **kwargs) -> Tuple[Any, Dict]: - """Run optimization with the custom pipeline""" - problem = PROBLEMS[problem_name] - kwargs["lr"] = 10 ** kwargs.get("lr", -2) - - # Initialize parameters - init = problem.init() - x = torch.nn.Parameter(torch.tensor(init, dtype=torch.float32)) - params = [x] - - # Build optimizer from pipeline - optimizer = build_optimizer_from_pipeline(pipeline, params, kwargs) - - # Run optimization - trajectory = [] - losses = [] - gradients = [] - - for i in range(steps): - trajectory.append(x.detach().numpy().copy()) - - def closure(): - optimizer.zero_grad() - loss 
= problem.loss(x)
-            loss.backward()
-            if x.grad is not None:
-                gradients.append(x.grad.detach().numpy().copy())
-            return loss
-
-        loss = optimizer.step(closure)
-        optimizer.zero_grad()
-        losses.append(loss.item())
-
-    trajectory = np.array(trajectory)
-
-    # Create visualization
-    fig = create_visualization(problem_name, trajectory, losses, gradients, pipeline)
-
-    return fig, {
-        "trajectory": trajectory.tolist(),
-        "losses": losses,
-        "final_loss": losses[-1],
-        "steps_to_converge": find_convergence(losses),
-    }
-
-
-def find_convergence(losses, threshold=1e-6):
-    """Find when optimization converged"""
-    if len(losses) < 10:
-        return len(losses)
-
-    for i in range(10, len(losses)):
-        if abs(losses[i] - losses[i - 5]) < threshold:
-            return i
-    return len(losses)
-
-
-def create_visualization(problem_name, trajectory, losses, gradients, pipeline):
-    """Create integrated visualization"""
-    problem = PROBLEMS[problem_name]
-
-    if problem.dim == 2:
-        return create_2d_visualization(problem_name, trajectory, losses, gradients, pipeline)
-    return create_highdim_visualization(problem_name, trajectory, losses, gradients, pipeline)
-
-
-def create_2d_visualization(problem_name, trajectory, losses, gradients, pipeline):
-    """Create visualization for 2D problems"""
-    problem = PROBLEMS[problem_name]
-    bounds = problem.bounds()
-
-    # Create subplots
-    fig = make_subplots(
-        rows=2,
-        cols=2,
-        subplot_titles=("Optimization Landscape", "Pipeline Architecture", "Loss Curve", "Learning Dynamics"),
-        column_widths=[0.6, 0.4],
-        row_heights=[0.6, 0.4],
-        specs=[[{"type": "contour"}, {"type": "scatter"}], [{"type": "scatter"}, {"type": "scatter"}]],
-    )
-
-    # 1.
Optimization landscape - x = np.linspace(bounds[0][0], bounds[0][1], 100) - y = np.linspace(bounds[1][0], bounds[1][1], 100) - X, Y = np.meshgrid(x, y) - Z = np.zeros_like(X) - - for i in range(X.shape[0]): - for j in range(X.shape[1]): - point = torch.tensor([X[i, j], Y[i, j]], dtype=torch.float32) - Z[i, j] = problem.loss(point).item() - - # Contour plot - fig.add_trace( - go.Contour( - x=x, - y=y, - z=Z, - colorscale=[[0, "#e3f2fd"], [0.5, "#2196f3"], [1, "#0d47a1"]], - showscale=False, - contours=dict( - start=0, - end=Z.max(), - size=Z.max() / 15, - ), - ), - row=1, - col=1, - ) - - # Add optimization path - if len(trajectory) > 0: - colors = np.linspace(0, 1, len(trajectory)) - - for i in range(1, len(trajectory)): - fig.add_trace( - go.Scatter( - x=[trajectory[i - 1, 0], trajectory[i, 0]], - y=[trajectory[i - 1, 1], trajectory[i, 1]], - mode="lines", - line=dict(color=f"rgba(255, {int(111 * (1 - colors[i]))}, 0, {0.3 + 0.7 * colors[i]})", width=3), - showlegend=False, - hoverinfo="skip", - ), - row=1, - col=1, - ) - - # Start and end points - fig.add_trace( - go.Scatter( - x=[trajectory[0, 0]], - y=[trajectory[0, 1]], - mode="markers+text", - marker=dict(size=12, color="#4caf50", line=dict(color="white", width=2)), - text=["Start"], - textposition="top center", - showlegend=False, - ), - row=1, - col=1, - ) - - fig.add_trace( - go.Scatter( - x=[trajectory[-1, 0]], - y=[trajectory[-1, 1]], - mode="markers+text", - marker=dict(size=12, color="#ff6f00", line=dict(color="white", width=2)), - text=["End"], - textposition="top center", - showlegend=False, - ), - row=1, - col=1, - ) - - # 2. 
Pipeline visualization - block_x = [] - block_y = [] - block_text = [] - block_colors = [] - - for i, comp_id in enumerate(pipeline): - block_x.append(i / (len(pipeline) - 1) if len(pipeline) > 1 else 0.5) - block_y.append(0.5) - - comp_info, comp_color = get_component_info(comp_id) - block_text.append(comp_info["icon"] if comp_info else "?") - block_colors.append(comp_color) - - fig.add_trace( - go.Scatter( - x=block_x, - y=block_y, - mode="markers+text", - marker=dict(size=40, color=block_colors, line=dict(color="white", width=2)), - text=block_text, - textposition="middle center", - showlegend=False, - hoverinfo="skip", - ), - row=1, - col=2, - ) - - # Add arrows between blocks - for i in range(len(pipeline) - 1): - fig.add_annotation( - x=block_x[i + 1], - y=0.5, - ax=block_x[i], - ay=0.5, - xref="x2", - yref="y2", - axref="x2", - ayref="y2", - showarrow=True, - arrowhead=2, - arrowsize=1, - arrowwidth=2, - arrowcolor="#ff6f00", - row=1, - col=2, - ) - - # 3. Loss curve - fig.add_trace( - go.Scatter( - x=list(range(len(losses))), - y=losses, - mode="lines", - line=dict(color="#ff6f00", width=3), - fill="tozeroy", - fillcolor="rgba(255, 111, 0, 0.1)", - showlegend=False, - ), - row=2, - col=1, - ) - - # 4. 
Learning dynamics (gradient norm)
-    if gradients:
-        grad_norms = [np.linalg.norm(g) for g in gradients]
-        fig.add_trace(
-            go.Scatter(
-                x=list(range(len(grad_norms))),
-                y=grad_norms,
-                mode="lines",
-                line=dict(color="#2196f3", width=2),
-                fill="tozeroy",
-                fillcolor="rgba(33, 150, 243, 0.1)",
-                showlegend=False,
-            ),
-            row=2,
-            col=2,
-        )
-
-    # Update layout
-    fig.update_layout(
-        height=700,
-        showlegend=False,
-        paper_bgcolor="white",
-        plot_bgcolor="white",
-        margin=dict(l=40, r=40, t=60, b=40),
-    )
-
-    # Update axes
-    fig.update_xaxes(showgrid=True, gridcolor="#f0f0f0")
-    fig.update_yaxes(showgrid=True, gridcolor="#f0f0f0")
-
-    # Pipeline plot
-    fig.update_xaxes(showticklabels=False, showgrid=False, row=1, col=2)
-    fig.update_yaxes(showticklabels=False, showgrid=False, range=[0, 1], row=1, col=2)
-
-    # Loss plot
-    fig.update_xaxes(title_text="Iteration", row=2, col=1)
-    fig.update_yaxes(title_text="Loss", type="log", row=2, col=1)
-
-    # Gradient plot
-    fig.update_xaxes(title_text="Iteration", row=2, col=2)
-    fig.update_yaxes(title_text="Gradient Norm", row=2, col=2)
-
-    return fig
-
-
-def create_highdim_visualization(problem_name, trajectory, losses, gradients, pipeline):
-    """Create visualization for high-dimensional problems using PCA"""
-    problem = PROBLEMS[problem_name]
-
-    # Create subplots
-    fig = make_subplots(
-        rows=2,
-        cols=2,
-        subplot_titles=("PCA Trajectory Projection", "Pipeline Architecture", "Loss Curve", "Learning Dynamics"),
-        column_widths=[0.6, 0.4],
-        row_heights=[0.6, 0.4],
-        specs=[[{"type": "contour"}, {"type": "scatter"}], [{"type": "scatter"}, {"type": "scatter"}]],
-    )
-
-    # 1.
PCA projection of trajectory - if len(trajectory) > 2: - # Fit PCA on trajectory - trajectory_array = np.array(trajectory) - pca = PCA(n_components=min(2, trajectory_array.shape[1])) - trajectory_2d = pca.fit_transform(trajectory_array) - - # Create loss landscape in PCA space - # Determine bounds based on trajectory - margin = 0.2 - x_min, x_max = trajectory_2d[:, 0].min(), trajectory_2d[:, 0].max() - x_range = x_max - x_min - x_min -= margin * x_range - x_max += margin * x_range - - if trajectory_2d.shape[1] > 1: - y_min, y_max = trajectory_2d[:, 1].min(), trajectory_2d[:, 1].max() - y_range = y_max - y_min - y_min -= margin * y_range - y_max += margin * y_range - else: - y_min, y_max = -1, 1 - - # Create grid in PCA space - n_points = 50 - x_grid = np.linspace(x_min, x_max, n_points) - y_grid = np.linspace(y_min, y_max, n_points) - X_grid, Y_grid = np.meshgrid(x_grid, y_grid) - - # Compute losses on grid - Z_grid = np.zeros_like(X_grid) - mean_trajectory = pca.mean_ - - for i in range(n_points): - for j in range(n_points): - # Map 2D point back to high-dimensional space - pca_coords = np.array([X_grid[i, j], Y_grid[i, j]]) - if trajectory_2d.shape[1] == 1: - pca_coords = pca_coords[:1] # Use only first coordinate - - # Reconstruct high-dimensional point - high_dim_point = mean_trajectory + pca_coords @ pca.components_[: len(pca_coords)] - - # Evaluate loss - x_tensor = torch.tensor(high_dim_point, dtype=torch.float32) - Z_grid[i, j] = problem.loss(x_tensor).item() - - # Add contour plot - fig.add_trace( - go.Contour( - x=x_grid, - y=y_grid, - z=Z_grid, - colorscale=[[0, "#e3f2fd"], [0.5, "#2196f3"], [1, "#0d47a1"]], - showscale=False, - contours=dict( - start=Z_grid.min(), - end=Z_grid.max(), - size=(Z_grid.max() - Z_grid.min()) / 15, - ), - ), - row=1, - col=1, - ) - - # Add trajectory on top of contour - colors = np.linspace(0, 1, len(trajectory_2d)) - - for i in range(1, len(trajectory_2d)): - fig.add_trace( - go.Scatter( - x=[trajectory_2d[i - 1, 0], 
trajectory_2d[i, 0]], - y=[trajectory_2d[i - 1, 1] if trajectory_2d.shape[1] > 1 else [0, 0]], - mode="lines", - line=dict(color=f"rgba(255, {int(111 * (1 - colors[i]))}, 0, {0.3 + 0.7 * colors[i]})", width=3), - showlegend=False, - hoverinfo="skip", - ), - row=1, - col=1, - ) - - # Start and end points - fig.add_trace( - go.Scatter( - x=[trajectory_2d[0, 0]], - y=[trajectory_2d[0, 1] if trajectory_2d.shape[1] > 1 else 0], - mode="markers+text", - marker=dict(size=12, color="#4caf50", line=dict(color="white", width=2)), - text=["Start"], - textposition="top center", - showlegend=False, - ), - row=1, - col=1, - ) - - fig.add_trace( - go.Scatter( - x=[trajectory_2d[-1, 0]], - y=[trajectory_2d[-1, 1] if trajectory_2d.shape[1] > 1 else 0], - mode="markers+text", - marker=dict(size=12, color="#ff6f00", line=dict(color="white", width=2)), - text=["End"], - textposition="top center", - showlegend=False, - ), - row=1, - col=1, - ) - - # Add explained variance text - if trajectory_2d.shape[1] > 1: - explained_var = pca.explained_variance_ratio_ - fig.add_annotation( - x=0.5, - y=1.1, - xref="x domain", - yref="y domain", - text=f"Explained variance: PC1={explained_var[0]:.1%}, PC2={explained_var[1]:.1%}", - showarrow=False, - row=1, - col=1, - ) - - # 2. 
Pipeline visualization (same as 2D) - block_x = [] - block_y = [] - block_text = [] - block_colors = [] - - for i, comp_id in enumerate(pipeline): - block_x.append(i / (len(pipeline) - 1) if len(pipeline) > 1 else 0.5) - block_y.append(0.5) - - comp_info, comp_color = get_component_info(comp_id) - block_text.append(comp_info["icon"] if comp_info else "?") - block_colors.append(comp_color) - - fig.add_trace( - go.Scatter( - x=block_x, - y=block_y, - mode="markers+text", - marker=dict(size=40, color=block_colors, line=dict(color="white", width=2)), - text=block_text, - textposition="middle center", - showlegend=False, - hoverinfo="skip", - ), - row=1, - col=2, - ) - - # Add arrows between blocks - for i in range(len(pipeline) - 1): - fig.add_annotation( - x=block_x[i + 1], - y=0.5, - ax=block_x[i], - ay=0.5, - xref="x2", - yref="y2", - axref="x2", - ayref="y2", - showarrow=True, - arrowhead=2, - arrowsize=1, - arrowwidth=2, - arrowcolor="#ff6f00", - row=1, - col=2, - ) - - # 3. Loss curve - fig.add_trace( - go.Scatter( - x=list(range(len(losses))), - y=losses, - mode="lines", - line=dict(color="#ff6f00", width=3), - fill="tozeroy", - fillcolor="rgba(255, 111, 0, 0.1)", - showlegend=False, - ), - row=2, - col=1, - ) - - # 4. 
Learning dynamics (gradient norm) - if gradients: - grad_norms = [np.linalg.norm(g) for g in gradients] - fig.add_trace( - go.Scatter( - x=list(range(len(grad_norms))), - y=grad_norms, - mode="lines", - line=dict(color="#2196f3", width=2), - fill="tozeroy", - fillcolor="rgba(33, 150, 243, 0.1)", - showlegend=False, - ), - row=2, - col=2, - ) - - # Update layout - fig.update_layout( - height=700, - showlegend=False, - paper_bgcolor="white", - plot_bgcolor="white", - margin=dict(l=40, r=40, t=60, b=40), - ) - - # Update axes - fig.update_xaxes(showgrid=True, gridcolor="#f0f0f0") - fig.update_yaxes(showgrid=True, gridcolor="#f0f0f0") - - # PCA plot - fig.update_xaxes(title_text="First Principal Component", row=1, col=1) - fig.update_yaxes(title_text="Second Principal Component", row=1, col=1) - - # Pipeline plot - fig.update_xaxes(showticklabels=False, showgrid=False, row=1, col=2) - fig.update_yaxes(showticklabels=False, showgrid=False, range=[0, 1], row=1, col=2) - - # Loss plot - fig.update_xaxes(title_text="Iteration", row=2, col=1) - fig.update_yaxes(title_text="Loss", type="log", row=2, col=1) - - # Gradient plot - fig.update_xaxes(title_text="Iteration", row=2, col=2) - fig.update_yaxes(title_text="Gradient Norm", row=2, col=2) - - return fig - - -def create_app(): - with gr.Blocks(theme=gr.themes.Base()) as app: - # Custom CSS - gr.HTML(""" - - """) - - # Header - gr.Markdown(""" - # 🧩 HeavyBall Chainable Optimizer Playground - - ### Build custom optimizers by combining components like LEGO blocks! - - Click components to add them to your pipeline. Each component transforms the gradient in a specific way - stack them to create powerful optimization algorithms! 
- """) - - # Hidden state for pipeline - pipeline_state = gr.State(["gradient_input"]) - - with gr.Row(): - # Left column - Component palette - with gr.Column(scale=1): - gr.Markdown("### 🎨 Component Palette") - gr.Markdown("*Click blocks to add to pipeline*") - - # Component buttons organized by category - for cat_id, category in OPTIMIZER_BLOCKS.items(): - if cat_id == "gradient": # Skip gradient input in palette - continue - - gr.HTML(f'
{category["name"]}
') - - with gr.Row(): - for comp in category["components"]: - btn = gr.Button( - value=f"{comp['icon']} {comp['name']}", elem_id=f"btn_{comp['id']}", size="sm" - ) - # Store component ID in button for click handler - btn.click( - fn=lambda p, cid=comp["id"]: p + [cid], - inputs=[pipeline_state], - outputs=[pipeline_state], - ) - - # Recipe selector - gr.Markdown("### 📚 Preset Recipes") - recipe_dropdown = gr.Dropdown(choices=list(RECIPES.keys()), value=None, label="Load a preset optimizer") - load_recipe_btn = gr.Button("Load Recipe", size="sm") - - # Center column - Main visualization - with gr.Column(scale=2): - # Pipeline builder - gr.Markdown("### 🔧 Pipeline Builder") - pipeline_display = gr.HTML() - - with gr.Row(): - clear_pipeline_btn = gr.Button("🗑️ Clear Pipeline", size="sm", variant="secondary") - refresh_btn = gr.Button("🔄 Refresh Display", size="sm") - - # Visualization - viz_plot = gr.Plot(label="") - - # Run button - run_btn = gr.Button("🚀 Run Optimization", variant="primary", size="lg") - - # Right column - Parameters and metrics - with gr.Column(scale=1): - gr.Markdown("### ⚙️ Parameters") - - problem_select = gr.Dropdown(choices=list(PROBLEMS.keys()), value="Rosenbrock (2D)", label="Problem") - - lr_slider = gr.Slider(minimum=-4, maximum=-1, value=-2, step=0.1, label="Learning Rate (10^x)") - - steps_slider = gr.Slider(minimum=10, maximum=500, value=200, step=10, label="Steps") - - # Component-specific parameters - with gr.Accordion("Advanced Parameters", open=False): - beta_slider = gr.Slider(minimum=0.0, maximum=0.999, value=0.9, step=0.001, label="Momentum β") - beta2_slider = gr.Slider(minimum=0.0, maximum=0.999, value=0.999, step=0.001, label="Adam β₂") - eps_slider = gr.Slider(minimum=1e-8, maximum=1e-4, value=1e-8, step=1e-8, label="Epsilon") - - # Weight decay parameters - weight_decay_slider = gr.Slider( - minimum=0.0, maximum=0.1, value=0.01, step=0.001, label="Weight Decay" - ) - ema_beta_slider = gr.Slider( - minimum=0.9, maximum=0.999, 
value=0.999, step=0.001, label="EMA β (for weight decay)" - ) - - # Gradient clipping - max_norm_slider = gr.Slider( - minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Gradient Clip Norm" - ) - - # Preconditioning parameters - shampoo_beta_slider = gr.Slider( - minimum=0.9, maximum=0.999, value=0.99, step=0.001, label="Shampoo β" - ) - precond_lr_slider = gr.Slider( - minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Preconditioner LR" - ) - precondition_frequency_slider = gr.Slider( - minimum=1, maximum=100, value=10, step=1, label="Precondition Frequency" - ) - rank_slider = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Low-rank Approximation Rank") - param_count_slider = gr.Slider( - minimum=1000, maximum=100000, value=10000, step=1000, label="Param Count (PSGD LRA)" - ) - - # Adaptive LR parameters - initial_d_slider = gr.Slider( - minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Initial D (D-adaptation)" - ) - lr_lr_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Learning Rate LR") - - # Special method parameters - r_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Schedule-Free r") - weight_lr_power_slider = gr.Slider( - minimum=0.0, maximum=4.0, value=2.0, step=0.1, label="Weight LR Power" - ) - sam_step_size_slider = gr.Slider( - minimum=0.01, maximum=0.5, value=0.05, step=0.01, label="SAM Step Size" - ) - beta2_scale_slider = gr.Slider( - minimum=0.5, maximum=1.0, value=0.8, step=0.01, label="PALM β₂ Scale" - ) - - # Sign SGD parameters - graft_checkbox = gr.Checkbox(value=True, label="Graft (Sign SGD)") - - gr.Markdown("### 📊 Metrics") - - final_loss_display = gr.Textbox(label="Final Loss", value="-") - convergence_display = gr.Textbox(label="Steps to Converge", value="-") - - # Footer - gr.Markdown(""" - --- - ### 💡 How it works: - - 1. **Start with Gradient** - Every pipeline begins with raw gradients - 2. 
**Add Transforms** - Click components to add them to your pipeline - 3. **Order Matters** - Components are applied in sequence - 4. **Run & Compare** - See how different combinations perform! - - **Example combinations:** - - Adam = Gradient → Adam Scaling - - AdamW = Gradient → Adam Scaling → Weight Decay - - Momentum SGD = Gradient → Heavy Ball - """) - - # Event handlers - def update_display(pipeline): - return create_pipeline_display(pipeline) - - def load_recipe(recipe_name): - if recipe_name in RECIPES: - return RECIPES[recipe_name].copy() - return ["gradient_input"] - - def clear_pipeline(): - return ["gradient_input"] - - def remove_block(pipeline, index): - """Remove a block from the pipeline""" - if 0 <= index < len(pipeline) and pipeline[index] != "gradient_input": - new_pipeline = pipeline.copy() - new_pipeline.pop(index) - return new_pipeline - return pipeline - - def run_optimization_handler( - problem, - pipeline, - steps, - lr, - beta, - beta2, - eps, - weight_decay, - ema_beta, - max_norm, - shampoo_beta, - precond_lr, - precondition_frequency, - rank, - param_count, - initial_d, - lr_lr, - r, - weight_lr_power, - sam_step_size, - beta2_scale, - graft, - ): - """Run optimization with current pipeline""" - if not pipeline or len(pipeline) == 1: - pipeline = ["gradient_input"] - - # Run optimization with all parameters - fig, metrics = run_optimization( - problem, - pipeline, - steps, - lr=lr, - beta=beta, - betas=(beta, beta2), - beta2=beta2, - eps=eps, - weight_decay_to_ema=weight_decay, - weight_decay_to_init=weight_decay, - ema_beta=ema_beta, - max_norm=max_norm, - shampoo_beta=shampoo_beta, - precond_lr=precond_lr, - precondition_frequency=precondition_frequency, - rank=rank, - param_count=param_count, - initial_d=initial_d, - lr_lr=lr_lr, - r=r, - weight_lr_power=weight_lr_power, - sam_step_size=sam_step_size, - beta2_scale=beta2_scale, - graft=graft, - ) - - return fig, f"{metrics['final_loss']:.2e}", str(metrics["steps_to_converge"]) - - # 
Connect events - # Update display when pipeline changes - pipeline_state.change(fn=update_display, inputs=[pipeline_state], outputs=[pipeline_display]) - - # Recipe loading - load_recipe_btn.click(fn=load_recipe, inputs=[recipe_dropdown], outputs=[pipeline_state]) - - # Clear pipeline - clear_pipeline_btn.click(fn=clear_pipeline, outputs=[pipeline_state]) - - # Refresh display - refresh_btn.click(fn=update_display, inputs=[pipeline_state], outputs=[pipeline_display]) - - # Run optimization - run_btn.click( - fn=run_optimization_handler, - inputs=[ - problem_select, - pipeline_state, - steps_slider, - lr_slider, - beta_slider, - beta2_slider, - eps_slider, - weight_decay_slider, - ema_beta_slider, - max_norm_slider, - shampoo_beta_slider, - precond_lr_slider, - precondition_frequency_slider, - rank_slider, - param_count_slider, - initial_d_slider, - lr_lr_slider, - r_slider, - weight_lr_power_slider, - sam_step_size_slider, - beta2_scale_slider, - graft_checkbox, - ], - outputs=[viz_plot, final_loss_display, convergence_display], - ) - - # Add JavaScript for removing blocks - app.load( - None, - None, - None, - js=""" - function() { - // Function to remove pipeline blocks - window.removePipelineBlock = function(index) { - // This would need to trigger a Gradio event - console.log('Remove block at index:', index); - // In a real implementation, this would update the pipeline state - }; - - console.log('Playground initialized'); - } - """, - ) - - # Initialize display - app.load( - fn=lambda: ( - create_pipeline_display(["gradient_input"]), - run_optimization("Rosenbrock (2D)", ["gradient_input", "adam_scale"], 200, lr=-2)[0], - ), - outputs=[pipeline_display, viz_plot], - ) - - return app - - -if __name__ == "__main__": - app = create_app() - app.launch(share=False) diff --git a/interactive/static/init-globals.js b/interactive/static/init-globals.js deleted file mode 100644 index b820bc9..0000000 --- a/interactive/static/init-globals.js +++ /dev/null @@ -1,2 +0,0 @@ 
-// This file initializes global variables for the node editor -// It will be populated by Python when the page loads diff --git a/interactive/static/node-editor.js b/interactive/static/node-editor.js deleted file mode 100644 index 525597f..0000000 --- a/interactive/static/node-editor.js +++ /dev/null @@ -1,570 +0,0 @@ -console.log('Node editor script starting execution...'); - -class NodeEditor { - constructor(containerId) { - this.container = document.getElementById(containerId); - this.canvas = document.getElementById('canvas'); - this.palette = document.getElementById('palette'); - this.inspector = document.getElementById('inspector'); - this.nodes = []; - this.connections = []; - this.selectedNode = null; - this.draggingNode = null; - this.connecting = null; - this.nodeIdCounter = 0; - this.offset = { x: 0, y: 0 }; - - this.init(); - } - - init() { - // Set up SVG for connections - this.svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); - this.svg.style.position = 'absolute'; - this.svg.style.width = '100%'; - this.svg.style.height = '100%'; - this.svg.style.pointerEvents = 'none'; - this.svg.style.zIndex = '1'; - this.svg.style.top = '0'; - this.svg.style.left = '0'; - this.canvas.appendChild(this.svg); - - // Set up event listeners - this.setupPalette(); - this.setupCanvas(); - - // Add gradient input node by default - this.addNode('gradient_input', 100, 300); - } - - setupPalette() { - // Make palette items draggable - const items = this.palette.querySelectorAll('.palette-item'); - items.forEach(item => { - item.draggable = true; - item.addEventListener('dragstart', (e) => this.onPaletteDragStart(e)); - item.addEventListener('dragend', (e) => this.onPaletteDragEnd(e)); - }); - } - - setupCanvas() { - this.canvas.addEventListener('dragover', (e) => e.preventDefault()); - this.canvas.addEventListener('drop', (e) => this.onCanvasDrop(e)); - this.canvas.addEventListener('mousedown', (e) => this.onCanvasMouseDown(e)); - 
this.canvas.addEventListener('mousemove', (e) => this.onCanvasMouseMove(e)); - this.canvas.addEventListener('mouseup', (e) => this.onCanvasMouseUp(e)); - } - - onPaletteDragStart(e) { - e.target.classList.add('dragging'); - e.dataTransfer.effectAllowed = 'copy'; - e.dataTransfer.setData('nodeType', e.target.dataset.type); - } - - onPaletteDragEnd(e) { - e.target.classList.remove('dragging'); - } - - onCanvasDrop(e) { - e.preventDefault(); - const nodeType = e.dataTransfer.getData('nodeType'); - if (nodeType) { - const rect = this.canvas.getBoundingClientRect(); - const x = e.clientX - rect.left - 80; - const y = e.clientY - rect.top - 30; - this.addNode(nodeType, x, y); - this.updateCodePreview(); - } - } - - addNode(type, x, y) { - console.log(`Adding node: type=${type}, x=${x}, y=${y}`); - const nodeId = `node-${this.nodeIdCounter++}`; - const nodeData = window.CHAINABLE_FUNCTIONS[type] || { - description: type === 'gradient_input' ? 'Gradient Input' : 'Unknown', - color: '#666', - inputs: type === 'gradient_input' ? [] : ['grad'], - outputs: type === 'gradient_input' ? ['grad'] : [] - }; - - const node = document.createElement('div'); - node.className = 'node'; - node.id = nodeId; - node.style.left = x + 'px'; - node.style.top = y + 'px'; - node.style.setProperty('--node-color', nodeData.color); - node.style.borderLeftColor = nodeData.color; - node.style.borderLeftWidth = '4px'; - console.log('Created node element:', node); - - // Special styling for gradient input - if (type === 'gradient_input') { - node.classList.add('gradient-input'); - } - - node.innerHTML = ` -
<div class="node-title">${type === 'gradient_input' ? 'Gradient Input' : type.replace(/_/g, ' ')}</div>
- <div class="node-description">${nodeData.description}</div>
- ${nodeData.inputs.length ? '<div class="node-port node-port-input"></div>' : ''}
- ${nodeData.outputs.length ? '<div class="node-port node-port-output"></div>
' : ''} - `; - - console.log('Canvas element:', this.canvas); - console.log('Canvas dimensions:', this.canvas.offsetWidth, 'x', this.canvas.offsetHeight); - this.canvas.appendChild(node); - console.log('Node appended to canvas. Total nodes in canvas:', this.canvas.querySelectorAll('.node').length); - - // Set up node interactions - node.addEventListener('mousedown', (e) => this.onNodeMouseDown(e, nodeId)); - - // Set up port interactions - const inputPort = node.querySelector('.node-port-input'); - const outputPort = node.querySelector('.node-port-output'); - - if (inputPort) { - inputPort.addEventListener('mousedown', (e) => this.onPortMouseDown(e, nodeId, 'input')); - } - if (outputPort) { - outputPort.addEventListener('mousedown', (e) => this.onPortMouseDown(e, nodeId, 'output')); - } - - // Store node data - this.nodes.push({ - id: nodeId, - type: type, - element: node, - x: x, - y: y, - params: nodeData.params || {} - }); - - return nodeId; - } - - onNodeMouseDown(e, nodeId) { - if (e.target.classList.contains('node-port')) return; - - e.preventDefault(); - this.selectNode(nodeId); - - const node = this.nodes.find(n => n.id === nodeId); - this.draggingNode = node; - - const rect = node.element.getBoundingClientRect(); - const canvasRect = this.canvas.getBoundingClientRect(); - this.offset = { - x: e.clientX - rect.left + canvasRect.left, - y: e.clientY - rect.top + canvasRect.top - }; - } - - onPortMouseDown(e, nodeId, portType) { - e.preventDefault(); - e.stopPropagation(); - - const port = e.target; - const rect = port.getBoundingClientRect(); - const canvasRect = this.canvas.getBoundingClientRect(); - - this.connecting = { - nodeId: nodeId, - portType: portType, - startX: rect.left + rect.width / 2 - canvasRect.left, - startY: rect.top + rect.height / 2 - canvasRect.top - }; - - // Create preview connection - const line = document.createElementNS('http://www.w3.org/2000/svg', 'path'); - line.classList.add('connection-preview'); - line.id = 
'preview-connection'; - this.svg.appendChild(line); - } - - onCanvasMouseMove(e) { - const rect = this.canvas.getBoundingClientRect(); - const x = e.clientX - rect.left; - const y = e.clientY - rect.top; - - if (this.draggingNode) { - this.draggingNode.element.style.left = (x - this.offset.x) + 'px'; - this.draggingNode.element.style.top = (y - this.offset.y) + 'px'; - this.draggingNode.x = x - this.offset.x; - this.draggingNode.y = y - this.offset.y; - this.updateConnections(); - } - - if (this.connecting) { - const preview = document.getElementById('preview-connection'); - if (preview) { - const path = this.createConnectionPath( - this.connecting.startX, - this.connecting.startY, - x, - y - ); - preview.setAttribute('d', path); - } - } - } - - onCanvasMouseUp(e) { - if (this.connecting) { - const preview = document.getElementById('preview-connection'); - if (preview) preview.remove(); - - // Check if we're over a port - const target = document.elementFromPoint(e.clientX, e.clientY); - if (target && target.classList.contains('node-port')) { - const targetNode = target.closest('.node'); - const targetNodeId = targetNode.id; - const targetPortType = target.classList.contains('node-port-input') ? 
'input' : 'output'; - - // Validate connection - if (this.canConnect(this.connecting.nodeId, this.connecting.portType, targetNodeId, targetPortType)) { - this.addConnection(this.connecting.nodeId, this.connecting.portType, targetNodeId, targetPortType); - this.updateCodePreview(); - } - } - - this.connecting = null; - } - - this.draggingNode = null; - } - - canConnect(fromNodeId, fromPortType, toNodeId, toPortType) { - // Can't connect to same node - if (fromNodeId === toNodeId) return false; - - // Must connect output to input - if (fromPortType === toPortType) return false; - - // Check if connection already exists - const exists = this.connections.some(c => - (c.from.nodeId === fromNodeId && c.to.nodeId === toNodeId) || - (c.from.nodeId === toNodeId && c.to.nodeId === fromNodeId) - ); - - return !exists; - } - - addConnection(fromNodeId, fromPortType, toNodeId, toPortType) { - // Ensure output connects to input - if (fromPortType === 'input') { - [fromNodeId, toNodeId] = [toNodeId, fromNodeId]; - [fromPortType, toPortType] = [toPortType, fromPortType]; - } - - const connection = { - from: { nodeId: fromNodeId, portType: fromPortType }, - to: { nodeId: toNodeId, portType: toPortType }, - element: null - }; - - const line = document.createElementNS('http://www.w3.org/2000/svg', 'path'); - line.classList.add('connection'); - line.classList.add('animated'); - this.svg.appendChild(line); - - connection.element = line; - this.connections.push(connection); - - this.updateConnections(); - } - - updateConnections() { - this.connections.forEach(conn => { - const fromNode = this.nodes.find(n => n.id === conn.from.nodeId); - const toNode = this.nodes.find(n => n.id === conn.to.nodeId); - - if (fromNode && toNode) { - const fromPort = fromNode.element.querySelector('.node-port-output'); - const toPort = toNode.element.querySelector('.node-port-input'); - - if (fromPort && toPort) { - const fromRect = fromPort.getBoundingClientRect(); - const toRect = 
toPort.getBoundingClientRect(); - const canvasRect = this.canvas.getBoundingClientRect(); - - const x1 = fromRect.left + fromRect.width / 2 - canvasRect.left; - const y1 = fromRect.top + fromRect.height / 2 - canvasRect.top; - const x2 = toRect.left + toRect.width / 2 - canvasRect.left; - const y2 = toRect.top + toRect.height / 2 - canvasRect.top; - - const path = this.createConnectionPath(x1, y1, x2, y2); - conn.element.setAttribute('d', path); - - // Store path coordinates for particle animation - conn.pathCoords = { x1, y1, x2, y2 }; - } - } - }); - } - - // Animate gradient flow particles - startGradientFlow() { - if (this.flowInterval) return; - - this.flowInterval = setInterval(() => { - this.connections.forEach(conn => { - if (conn.pathCoords) { - this.createFlowParticle(conn.pathCoords); - } - }); - }, 500); - } - - stopGradientFlow() { - if (this.flowInterval) { - clearInterval(this.flowInterval); - this.flowInterval = null; - } - - // Remove all particles - const particles = this.canvas.querySelectorAll('.flow-particle'); - particles.forEach(p => p.remove()); - } - - createFlowParticle(coords) { - const particle = document.createElement('div'); - particle.className = 'flow-particle'; - particle.style.left = coords.x1 + 'px'; - particle.style.top = coords.y1 + 'px'; - this.canvas.appendChild(particle); - - // Animate along bezier curve - const duration = 2000; - const startTime = Date.now(); - - const animate = () => { - const elapsed = Date.now() - startTime; - const t = Math.min(elapsed / duration, 1); - - if (t >= 1) { - particle.remove(); - return; - } - - // Calculate position on bezier curve - const dx = Math.abs(coords.x2 - coords.x1); - const cp1x = coords.x1 + dx * 0.5; - const cp2x = coords.x2 - dx * 0.5; - - // Bezier curve formula - const x = Math.pow(1-t, 3) * coords.x1 + - 3 * Math.pow(1-t, 2) * t * cp1x + - 3 * (1-t) * Math.pow(t, 2) * cp2x + - Math.pow(t, 3) * coords.x2; - - const y = Math.pow(1-t, 3) * coords.y1 + - 3 * Math.pow(1-t, 2) * 
t * coords.y1 + - 3 * (1-t) * Math.pow(t, 2) * coords.y2 + - Math.pow(t, 3) * coords.y2; - - particle.style.left = x + 'px'; - particle.style.top = y + 'px'; - particle.style.opacity = Math.sin(t * Math.PI); - - requestAnimationFrame(animate); - }; - - requestAnimationFrame(animate); - } - - createConnectionPath(x1, y1, x2, y2) { - const dx = Math.abs(x2 - x1); - const cp1x = x1 + dx * 0.5; - const cp2x = x2 - dx * 0.5; - return `M ${x1} ${y1} C ${cp1x} ${y1}, ${cp2x} ${y2}, ${x2} ${y2}`; - } - - selectNode(nodeId) { - // Deselect previous - if (this.selectedNode) { - this.selectedNode.element.classList.remove('selected'); - } - - // Select new - const node = this.nodes.find(n => n.id === nodeId); - if (node) { - node.element.classList.add('selected'); - this.selectedNode = node; - this.showInspector(node); - } - } - - showInspector(node) { - this.inspector.classList.add('active'); - - const nodeData = window.CHAINABLE_FUNCTIONS[node.type]; - if (!nodeData || !nodeData.params) return; - - // Build inspector content - let html = `
<div class="inspector-title">${node.type.replace(/_/g, ' ')}</div>
`; - - Object.entries(nodeData.params).forEach(([key, defaultValue]) => { - const value = node.params[key] || defaultValue; - html += ` -
<div class="inspector-field">
- <label class="inspector-label">${key}</label>
- <input class="inspector-input" type="${typeof value === 'number' ? 'number' : 'text'}" data-param="${key}" value="${value}">
- </div>
- `; - }); - - this.inspector.innerHTML = html; - - // Set up parameter change listeners - this.inspector.querySelectorAll('.inspector-input').forEach(input => { - input.addEventListener('change', (e) => { - const param = e.target.dataset.param; - let value = e.target.value; - - // Convert to appropriate type - if (e.target.type === 'number') { - value = parseFloat(value); - } - - node.params[param] = value; - this.updateCodePreview(); - }); - }); - } - - updateCodePreview() { - const codePreview = document.getElementById('code-output'); - if (!codePreview) return; - - // Generate Python code from pipeline - const pipelineData = this.exportPipeline(); - const nodes = pipelineData.nodes.filter(n => n.type !== 'gradient_input'); - - if (nodes.length === 0) { - codePreview.textContent = '# Add optimizer components to build your pipeline'; - return; - } - - let code = 'optimizer = BaseOpt(\n'; - code += ' model.parameters(),\n'; - code += ' lr=0.001,\n'; - - // Add relevant parameters - const params = {}; - nodes.forEach(node => { - Object.assign(params, node.params); - }); - - Object.entries(params).forEach(([key, value]) => { - if (typeof value === 'string') { - code += ` ${key}="${value}",\n`; - } else { - code += ` ${key}=${value},\n`; - } - }); - - code += ' fns=[\n'; - nodes.forEach(node => { - code += ` ${node.type},\n`; - }); - code += ' ]\n'; - code += ')'; - - codePreview.textContent = code; - - // Update hidden pipeline data - const pipelineInput = document.getElementById('pipeline-data'); - if (pipelineInput) { - pipelineInput.value = JSON.stringify(pipelineData); - } - } - - exportPipeline() { - return { - nodes: this.nodes.map(n => ({ - id: n.id, - type: n.type, - x: n.x, - y: n.y, - params: n.params - })), - connections: this.connections.map(c => ({ - from: c.from, - to: c.to - })) - }; - } - - loadRecipe(recipe) { - // Clear existing nodes except gradient input - this.nodes = this.nodes.filter(n => n.type === 'gradient_input'); - this.connections = []; 
- this.selectedNode = null; - - // Clear canvas except gradient input - const nodesToRemove = this.canvas.querySelectorAll('.node:not([id="node-0"])'); - nodesToRemove.forEach(n => n.remove()); - - // Clear connections - this.svg.innerHTML = ''; - - // Add nodes from recipe - let lastNodeId = 'node-0'; // gradient input - let x = 300; - const y = 300; - - recipe.forEach((item, index) => { - const nodeId = this.addNode(item.name, x, y); - const node = this.nodes.find(n => n.id === nodeId); - - // Set parameters - if (item.params) { - node.params = item.params; - } - - // Connect to previous node - this.addConnection(lastNodeId, 'output', nodeId, 'input'); - - lastNodeId = nodeId; - x += 200; - }); - - this.updateCodePreview(); - } -} - -// Initialize when DOM is ready -function initializeNodeEditor() { - try { - console.log('Attempting to initialize NodeEditor...'); - const editorElement = document.getElementById('node-editor'); - if (!editorElement) { - console.error('node-editor element not found! 
Retrying in 100ms...'); - setTimeout(initializeNodeEditor, 100); - return; - } - console.log('Found node-editor element:', editorElement); - console.log('Editor dimensions:', editorElement.offsetWidth, 'x', editorElement.offsetHeight); - - if (window.nodeEditor) { - console.log('NodeEditor already initialized'); - return; - } - - window.nodeEditor = new NodeEditor('node-editor'); - console.log('NodeEditor initialized successfully'); - } catch (error) { - console.error('Error initializing NodeEditor:', error); - } -} - -// Try multiple initialization methods -document.addEventListener('DOMContentLoaded', initializeNodeEditor); -// Also try after a delay in case content loads after DOMContentLoaded -setTimeout(initializeNodeEditor, 100); -setTimeout(initializeNodeEditor, 500); -setTimeout(initializeNodeEditor, 1000); diff --git a/pyproject.toml b/pyproject.toml index 36916e8..7c2d78d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "heavyball" -description = "Compile-first PyTorch optimizer library — AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features" -version = "2.3.2" +description = "Compile-first PyTorch optimizer library - AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features" +version = "3.0.0" authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }] classifiers = ["Intended Audience :: Developers", "Intended Audience :: Science/Research", diff --git a/scripts/migrate_optimizer_state.py b/scripts/migrate_optimizer_state.py index fb9f6e2..162538c 100644 --- a/scripts/migrate_optimizer_state.py +++ b/scripts/migrate_optimizer_state.py @@ -1,18 +1,21 @@ #!/usr/bin/env python3 """ -Utility to migrate HeavyBall 1.x optimizer state dicts to the 2.0.0 layout. +Utility to migrate HeavyBall optimizer state dicts to the current (3.x) layout. 
-The script rewrites per-parameter state keys to the new transform-indexed names, -reshapes state storage so each parameter-view owns its own dictionary, and -injects the HeavyBall-specific metadata block expected by 2.0.0 optimizers. +Supports two migration paths: +- 1.x → 3.x: rewrites per-parameter state keys, reshapes storage, fixes param groups and metadata +- 2.x → 3.x: renames foreach→multi_tensor in param groups, strips stale metadata + +Version auto-detection: +- 1.x: flat per-parameter state (not nested by view index) +- 2.x: nested state with 'foreach' in param groups +- 3.x: nested state with 'multi_tensor' in param groups (already current) """ from __future__ import annotations import functools import importlib -import pickle -import random from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple @@ -20,6 +23,56 @@ import torch import typer +_CLASS_RENAMES = { + # Foreach* internal names → canonical 3.x names + "ForeachAdamW": "AdamW", + "ForeachNAdam": "NAdam", + "ForeachAdEMAMix": "AdEMAMix", + "ForeachAdamC": "AdamC", + "ForeachRMSprop": "RMSprop", + "ForeachSFAdamW": "SFAdamW", + "ForeachADOPT": "ADOPT", + "ForeachMuon": "Muon", + "ForeachLaProp": "LaProp", + "ForeachSOAP": "SOAP", + "ForeachSOAPNAdam": "SOAPNAdam", + "ForeachSOAPAdEMAMix": "SOAPAdEMAMix", + "ForeachSignLaProp": "SignLaProp", + "ForeachSOLP": "SOLP", + "ForeachPSGDKron": "PSGDKron", + "ForeachPSGDLRA": "PSGDLRA", + # Deleted subclasses → parent (all were __init__-only default overrides) + "PaLMForeachSFAdamW": "SFAdamW", + "PaLMForeachSOAP": "SOAP", + "PrecondScheduleForeachSOAP": "SOAP", + "PrecondSchedulePaLMForeachSOAP": "SOAP", + "ForeachPurePSGD": "PSGDKron", + "ForeachCachedPSGDKron": "PSGDKron", + "ForeachCachedDelayedPSGDKron": "PSGDKron", + "ForeachDelayedPSGD": "PSGDKron", + "ForeachCachedNewtonPSGD": "PSGDKron", + "NewtonHybrid2PSGDKron": "PSGDKron", + "ForeachDelayedPSGDLRA": "PSGDLRA", + 
"ForeachNewtonPSGDLRA": "PSGDLRA", + "NewtonHybrid2PSGDLRA": "PSGDLRA", + # 2.x public aliases that pointed to Foreach* classes (now removed) + "PaLMSOAP": "SOAP", + "PaLMSFAdamW": "SFAdamW", + "PalmForEachSoap": "SOAP", + "PrecondScheduleSOAP": "SOAP", + "PrecondSchedulePaLMSOAP": "SOAP", + "PurePSGD": "PSGDKron", + "DelayedPSGD": "PSGDKron", + "CachedPSGDKron": "PSGDKron", + "CachedDelayedPSGDKron": "PSGDKron", + "NewtonPSGDKron": "PSGDKron", + "DelayedPSGDLRA": "PSGDLRA", + "NewtonPSGDLRA": "PSGDLRA", +} + +_REMOVED_GROUP_KEYS = frozenset({"stochastic_schedule"}) +_REMOVED_META_KEYS = frozenset({"stochastic_schedule", "precond_rng"}) + @dataclass class TransformMapping: @@ -28,16 +81,37 @@ class TransformMapping: transform_idx: int +def _resolve_class_name(class_name: str) -> str: + return _CLASS_RENAMES.get(class_name, class_name) + + def _load_optimizer_class(qualified_name: str): if "." in qualified_name: module_name, class_name = qualified_name.rsplit(".", 1) else: module_name, class_name = "heavyball", qualified_name + class_name = _resolve_class_name(class_name) module = importlib.import_module(module_name) try: return getattr(module, class_name) except AttributeError as exc: - raise ValueError(f"Optimizer class '{qualified_name}' not found") from exc + raise ValueError(f"Optimizer class '{module_name}.{class_name}' not found") from exc + + +def _detect_version(state_dict: Dict[str, Any]) -> int: + nested = False + for entry in state_dict.get("state", {}).values(): + if isinstance(entry, dict): + nested = any(isinstance(v, dict) for v in entry.values()) + if not nested: + return 1 + break + + groups = state_dict.get("param_groups", []) + has_multi_tensor = any(isinstance(g, dict) and "multi_tensor" in g for g in groups) + if nested: + return 3 if has_multi_tensor else 2 + return 3 if has_multi_tensor else 1 def _guess_tensor_meta(state_entry: Dict[str, Any]) -> Tuple[Tuple[int, ...], torch.dtype]: @@ -69,8 +143,9 @@ def _build_dummy_parameters(state: 
Dict[int, Dict[str, Any]], param_groups: Sequ def _normalise_group_options(group: Dict[str, Any]) -> Dict[str, Any]: options: Dict[str, Any] = {} for key, value in group.items(): - if key == "params": + if key == "params" or key in _REMOVED_GROUP_KEYS: continue + key = "multi_tensor" if key == "foreach" else key if isinstance(value, list) and key in {"betas", "weight_decay_steps"}: options[key] = tuple(value) else: @@ -104,9 +179,15 @@ def walk(queue: Iterable[Any]): stack.append(current.fn) elif isinstance(current, functools.partial): # type: ignore[name-defined] stack.append(current.func) - elif isinstance(current, C.Branch): + elif isinstance(current, C.Parallel): for branch in current.branches: stack.extend(branch) + elif isinstance(current, C.Route): + for _, fns in current.routes: + if fns: + stack.extend(fns) + if current.default: + stack.extend(current.default) elif isinstance(current, (list, tuple)): stack.extend(current) @@ -180,7 +261,24 @@ def _migrate_single_state(entry: Dict[str, Any], mappings: List[TransformMapping return migrated -def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]: +def _migrate_v2_to_v3(state_dict: Dict[str, Any]) -> None: + for group in state_dict.get("param_groups", []): + if not isinstance(group, dict): + continue + if "foreach" in group: + group["multi_tensor"] = group.pop("foreach") + for key in _REMOVED_GROUP_KEYS: + group.pop(key, None) + + hb = state_dict.get("heavyball", {}) + for key in _REMOVED_META_KEYS: + hb.pop(key, None) + inner = hb.get("inner_group", {}) + for key in _REMOVED_META_KEYS: + inner.pop(key, None) + + +def _migrate_v1_state(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]: opt_cls = _load_optimizer_class(optimizer_class) optimizer, _ = _instantiate_optimizer(opt_cls, old_state) template = optimizer.state_dict() @@ -199,11 +297,19 @@ def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[ heavyball_meta = 
migrated.setdefault("heavyball", template.get("heavyball", {})) if "inner_group" not in heavyball_meta: - heavyball_meta["inner_group"] = {"stochastic_schedule": None} - if "stochastic_schedule" not in heavyball_meta: - heavyball_meta["stochastic_schedule"] = None - if "precond_rng" not in heavyball_meta: - heavyball_meta["precond_rng"] = pickle.dumps(random.Random(0x12312)) + heavyball_meta["inner_group"] = {} + return migrated + + +def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]: + version = _detect_version(old_state) + if version == 1: + migrated = _migrate_v1_state(old_state, optimizer_class) + elif version == 2: + migrated = dict(old_state) + else: + return dict(old_state) + _migrate_v2_to_v3(migrated) return migrated @@ -221,7 +327,7 @@ def _resolve_state_container(root: Dict[str, Any], key_path: Sequence[str]) -> D app = typer.Typer(help="Utilities for migrating HeavyBall optimizer checkpoints.") -@app.command(help="Migrate a HeavyBall optimizer state dict to the 2.0.0 layout.") +@app.command(help="Migrate a HeavyBall optimizer state dict to the current (3.x) layout.") def migrate( checkpoint: Path = typer.Argument( ..., @@ -234,7 +340,7 @@ def migrate( ), optimizer_class: str = typer.Argument( ..., - help="Optimizer class to instantiate (e.g., heavyball.ForeachAdamW)", + help="Optimizer class name (e.g., heavyball.AdamW). 
Old names like ForeachAdamW are resolved automatically.", ), state_key: str = typer.Option( "optimizer", diff --git a/test/test_ademamix.py b/test/test_ademamix.py index 847a290..49e53de 100644 --- a/test/test_ademamix.py +++ b/test/test_ademamix.py @@ -101,7 +101,7 @@ def test_ademamix_matches_reference_math(): alpha_warmup = 4 param = torch.nn.Parameter(initial.clone()) - optimizer = heavyball.ForeachAdEMAMix( + optimizer = heavyball.AdEMAMix( [param], lr=lr, betas=betas, @@ -110,7 +110,7 @@ def test_ademamix_matches_reference_math(): alpha=alpha, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, - foreach=False, + multi_tensor=False, ) for grad in grads: @@ -149,7 +149,7 @@ def test_soap_ademamix_projects_gradients_into_eigenbasis(): torch.manual_seed(7) param = torch.nn.Parameter(torch.randn(2, 2)) - optimizer = heavyball.ForeachSOAPAdEMAMix([param], lr=0.01, foreach=False) + optimizer = heavyball.SOAPAdEMAMix([param], lr=0.01, multi_tensor=False) # First call initializes the SOAP preconditioner state without applying an update. 
param.grad = torch.randn_like(param) diff --git a/test/test_chainable_cpu.py b/test/test_chainable_cpu.py index 124cbe4..ed2da33 100644 --- a/test/test_chainable_cpu.py +++ b/test/test_chainable_cpu.py @@ -36,7 +36,7 @@ def negate(_, __, update, ___, ____): def merge_fn(outputs): return [sum(vals) / len(vals) for vals in zip(*outputs)] - branch = C.Branch([[double], [negate]], merge_fn) + branch = C.Parallel([[double], [negate]], merge_fn) update = [torch.ones(2)] grad = [torch.ones(2)] @@ -70,46 +70,37 @@ def state_fn(_x): # Optimizers whose chains are purely elementwise must NOT need gather _EXPECT_NO_GATHER = { "SGD", - "ForeachAdamW", - "ForeachNAdam", - "ForeachAdEMAMix", + "AdamW", + "NAdam", + "AdEMAMix", "UnscaledAdamW", - "ForeachAdamC", - "ForeachRMSprop", - "ForeachSFAdamW", - "ForeachADOPT", - "ForeachLaProp", - "PaLMForeachSFAdamW", + "AdamC", + "RMSprop", + "SFAdamW", + "ADOPT", + "LaProp", } # Optimizers whose chains use shape-dependent or global-reduction ops must need gather _EXPECT_GATHER = { - "ForeachSOAP", - "ForeachSOAPNAdam", - "ForeachSOAPAdEMAMix", - "ForeachSOLP", - "ForeachMuon", + "SOAP", + "SOAPNAdam", + "SOAPAdEMAMix", + "SOLP", + "Muon", "MuonLaProp", "OrthoLaProp", "LaPropOrtho", - "ForeachPSGDKron", - "ForeachPurePSGD", - "ForeachCachedPSGDKron", - "ForeachDelayedPSGD", - "ForeachCachedDelayedPSGDKron", - "ForeachCachedNewtonPSGD", - "NewtonHybrid2PSGDKron", - "ForeachPSGDLRA", - "ForeachDelayedPSGDLRA", - "ForeachNewtonPSGDLRA", - "NewtonHybrid2PSGDLRA", + "PSGDKron", + "LATHER", + "PSGDLRA", + "PSGDPRO", "SUDSAdamW", "Scion", - "ForeachSignLaProp", + "SignLaProp", "MSAMLaProp", - "PaLMForeachSOAP", - "PrecondScheduleForeachSOAP", - "PrecondSchedulePaLMForeachSOAP", + "HyperBallAdamW", + "MuonAdamW", } _SKIP_INSTANTIATE = {"SplitOpt", "SAMWrapper"} @@ -120,7 +111,7 @@ def state_fn(_x): @pytest.mark.parametrize("opt_name", _ALL_OPTS) def test_needs_gather_flag(opt_name): params = [torch.nn.Parameter(torch.randn(4, 4))] - extra = 
{"max_lr": 0.0025} if opt_name == "ForeachAdamC" else {} + extra = {"max_lr": 0.0025} if opt_name == "AdamC" else {} opt = getattr(heavyball, opt_name)(params, lr=1e-3, **extra) if opt_name in _EXPECT_NO_GATHER: assert not opt._needs_gather, f"{opt_name} should be elementwise (no gather needed)" diff --git a/test/test_compile_step.py b/test/test_compile_step.py new file mode 100644 index 0000000..ce0a8e5 --- /dev/null +++ b/test/test_compile_step.py @@ -0,0 +1,114 @@ +import inspect + +import pytest +import torch + +import heavyball +from heavyball.chainable import ChainOpt, WarmupGuard, _walk_fns + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +EXTRA_KWARGS = { + "AdamC": {"max_lr": 0.01}, +} + + +def _optimizer_params(): + seen = set() + params = [] + for name in heavyball.__all__: + if not hasattr(heavyball, name): + continue + obj = getattr(heavyball, name) + if not inspect.isclass(obj): + continue + if not issubclass(obj, torch.optim.Optimizer): + continue + ident = id(obj) + if ident in seen: + continue + seen.add(ident) + if name == "SplitOpt": + params.append( + pytest.param(name, obj, id=name, marks=pytest.mark.skip(reason="SplitOpt requires dict specs")) + ) + continue + params.append(pytest.param(name, obj, id=name)) + return params + + +def _make_model(): + return torch.nn.Sequential( + torch.nn.Linear(8, 16), + torch.nn.Tanh(), + torch.nn.Linear(16, 4), + ).to(DEVICE) + + +def _run_steps(model, optimizer, n=5, seed=0xDEADBEEF): + torch.manual_seed(seed) + for _ in range(n): + + def closure(): + optimizer.zero_grad(set_to_none=True) + data = torch.randn(4, 8, device=DEVICE) + target = torch.randn(4, 4, device=DEVICE) + loss = torch.nn.functional.mse_loss(model(data), target) + loss.backward() + return loss + + optimizer.step(closure) + + +@pytest.mark.parametrize("opt_name,opt_cls", _optimizer_params()) +def test_compile_step_matches_eager(opt_name, opt_cls): + """compile_step=True must produce identical parameters to 
compile_step=False.""" + sig = inspect.signature(opt_cls.__init__) + if "compile_step" not in sig.parameters: + pytest.skip("optimizer does not accept compile_step") + + kwargs = dict(EXTRA_KWARGS.get(opt_name, {})) + + torch.manual_seed(0xDEADBEEF) + model_ref = _make_model() + model_test = _make_model() + model_test.load_state_dict(model_ref.state_dict()) + + opt_ref = opt_cls(model_ref.parameters(), compile_step=False, **kwargs) + opt_test = opt_cls(model_test.parameters(), compile_step=True, **kwargs) + + _run_steps(model_ref, opt_ref) + _run_steps(model_test, opt_test) + + for p_ref, p_test in zip(model_ref.parameters(), model_test.parameters()): + diff = (p_ref.data - p_test.data).abs().max().item() + assert diff < 1e-4, f"compile_step diverged: max_diff={diff}" + + +def _max_warmup(opt): + return max((len(ft.warmup_fns) for ft in _walk_fns(opt.fns) if isinstance(ft, WarmupGuard)), default=0) + + +@pytest.mark.parametrize("opt_name,opt_cls", _optimizer_params()) +def test_needs_init_clears(opt_name, opt_cls): + """_needs_init must become False after max_warmup + 1 steps for all ChainOpt optimizers. + + Catches bugs where Route-based or warmup_guard-based optimizers permanently + force eager mode because different params accumulate different is_initialized + sets that never individually cover _transform_ids. 
+ """ + if not issubclass(opt_cls, ChainOpt): + pytest.skip("not a ChainOpt") + + kwargs = dict(EXTRA_KWARGS.get(opt_name, {})) + model = _make_model() + opt = opt_cls(model.parameters(), **kwargs) + n = _max_warmup(opt) + 1 + + _run_steps(model, opt, n=n) + + for group in opt.param_groups: + state = [opt.state_(p) for p in group["params"]] + assert not opt._needs_init(state), ( + f"{opt_name}: _needs_init stuck True after {n} steps | compile_step will never engage" + ) diff --git a/test/test_cpu_features.py b/test/test_cpu_features.py index e241f98..0a8b746 100644 --- a/test/test_cpu_features.py +++ b/test/test_cpu_features.py @@ -44,9 +44,9 @@ def _make_batch( @pytest.mark.parametrize( "opt_name", [ - "ForeachSOAP", + "SOAP", "Muon", - "ForeachAdamW", + "AdamW", ], ) def test_selected_optimizers_run_on_cpu(opt_name: str) -> None: @@ -92,8 +92,8 @@ def test_mars_flag_changes_behavior() -> None: model_a, data, target = _make_batch() model_b = deepcopy(model_a) - opt_a = heavyball.ForeachAdamW(model_a.parameters(), mars=False, warmup_steps=0) - opt_b = heavyball.ForeachAdamW(model_b.parameters(), mars=True, warmup_steps=0) + opt_a = heavyball.AdamW(model_a.parameters(), mars=False, warmup_steps=0) + opt_b = heavyball.AdamW(model_b.parameters(), mars=True, warmup_steps=0) init = [param.detach().clone() for param in model_a.parameters()] @@ -112,7 +112,7 @@ def test_mars_flag_changes_behavior() -> None: def test_sam_wrapper_requires_closure() -> None: model = nn.Linear(4, 2) - base = heavyball.ForeachAdamW(model.parameters()) + base = heavyball.AdamW(model.parameters()) wrapper = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=base) with pytest.raises(ValueError): @@ -132,3 +132,83 @@ def closure(): after = [param.detach() for param in model.parameters()] diff = torch.cat([(a - b).reshape(-1) for a, b in zip(after, before, strict=True)]) assert diff.norm().item() > 0.0 + + +def test_multiple_param_groups_keep_updating() -> None: + p1 = 
nn.Parameter(torch.zeros(())) + p2 = nn.Parameter(torch.zeros(())) + opt = heavyball.SGD( + [ + {"params": [p1]}, + {"params": [p2]}, + ], + lr=0.1, + beta=0.0, + warmup_steps=0, + ) + + for _ in range(3): + p1.grad = torch.ones_like(p1) + p2.grad = torch.full_like(p2, 2.0) + opt.step() + opt.zero_grad(set_to_none=True) + + assert torch.allclose(p1.detach(), torch.tensor(-0.3)) + assert torch.allclose(p2.detach(), torch.tensor(-0.6)) + + +def test_group_step_does_not_reset_when_active_param_changes() -> None: + p1 = nn.Parameter(torch.zeros(())) + p2 = nn.Parameter(torch.zeros(())) + opt = heavyball.SGD([p1, p2], lr=0.1, beta=0.0, warmup_steps=3) + + p1.grad = torch.ones_like(p1) + opt.step() + opt.zero_grad(set_to_none=True) + + p2.grad = torch.ones_like(p2) + opt.step() + opt.zero_grad(set_to_none=True) + + assert p1.item() == pytest.approx(-0.025) + assert p2.item() == pytest.approx(-0.05) + + +def test_string_clipping_shorthands_match_public_api() -> None: + model, data, target = _make_batch() + opt = heavyball.SGD( + model.parameters(), + lr=1e-3, + beta=0.0, + gradient_clipping="l2_clip_", + update_clipping="trust_region_clip_", + ) + + loss = _train_once(opt, model, data, target, steps=2) + assert torch.isfinite(torch.tensor(loss)) + + +@pytest.mark.parametrize("opt_cls", [heavyball.SFAdamW, heavyball.MSAMLaProp]) +def test_mode_switches_are_idempotent(opt_cls) -> None: + p = nn.Parameter(torch.tensor([1.0, -1.0])) + opt = opt_cls([p], lr=1e-2) + + p.grad = torch.ones_like(p) + opt.step() + opt.zero_grad(set_to_none=True) + + opt.eval() + eval_once = p.detach().clone() + opt.eval() + eval_twice = p.detach().clone() + + assert torch.allclose(eval_once, eval_twice) + assert opt.param_groups[0]["train_mode"] is False + + opt.train() + train_once = p.detach().clone() + opt.train() + train_twice = p.detach().clone() + + assert torch.allclose(train_once, train_twice) + assert opt.param_groups[0]["train_mode"] is True diff --git a/test/test_distributed.py 
b/test/test_distributed.py index 1b60e7b..d96d9d1 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -17,14 +17,13 @@ pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), ] -_EXTRA_KWARGS = {"ForeachAdamC": {"max_lr": 0.0025}} +_EXTRA_KWARGS = {"AdamC": {"max_lr": 0.0025}} _MODEL_SEED = 42 _DATA_SEED = 0xABCD # LRA builds one preconditioner over all grads, under FSDP each rank only has a subset _FSDP_SKIP = { - "ForeachPSGDLRA": "LRA preconditioner scope differs under FSDP", - "ForeachDelayedPSGDLRA": "LRA preconditioner scope differs under FSDP", + "PSGDLRA": "LRA preconditioner scope differs under FSDP", } # torch.compile(dynamic=False) specializes on list length → different kernels per rank @@ -39,20 +38,20 @@ _INTEGRATION_OPTS = [ n for n in [ - "ForeachAdamW", - "ForeachSOAP", - "ForeachMuon", - "ForeachPSGDKron", - "ForeachPurePSGD", + "AdamW", + "SOAP", + "Muon", + "PSGDKron", "Scion", - "ForeachLaProp", + "LaProp", "MuonLaProp", - "ForeachSOAPNAdam", - "ForeachCachedPSGDKron", + "SOAPNAdam", ] if n in REPRESENTATIVE_OPTS and n not in _FSDP_SKIP ] +_OWNER_EDGE_OPTS = [n for n in ["Muon"] if n in REPRESENTATIVE_OPTS and n not in _FSDP_SKIP] + def _set_cache(cache_dir, compile_mode="default"): os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir @@ -104,6 +103,14 @@ def _make_integration_model(): ) +def _make_owner_edge_model(): + return nn.Sequential( + nn.Linear(32, 32, bias=False), + nn.ReLU(), + nn.Linear(32, 1, bias=False), + ) + + def _make_data(dim=32, n=4): torch.manual_seed(_DATA_SEED) return [torch.randn(4, dim, device="cuda") for _ in range(n)] @@ -121,6 +128,10 @@ def _make_integration_data(): return _make_data(64, 8) +def _make_owner_edge_data(): + return _make_data(32, 6) + + def _train(model, opt, data): for x in data: model(x).mean().backward() @@ -266,7 +277,7 @@ def _assert_close(ref, result, label, rtol=0, atol=0): assert False, f"{label}: param {i} diverged beyond 1 ULP ({n} elements, max 
|diff|={worst:.2e})" -def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label): +def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label, world_size=2, tol=None): cache_dir = tempfile.mkdtemp(prefix=f"hb_{label}_{opt_name}_") ref_path = str(tmp_path / "ref.pt") mp.spawn(_ref_worker, args=(opt_name, ref_path, cache_dir, None, model_fn, data_fn), nprocs=1, join=True) @@ -274,12 +285,14 @@ def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label): result_path = str(tmp_path / "result.pt") mp.spawn( _fsdp_worker, - args=(2, str(tmp_path / "store"), opt_name, result_path, cache_dir, None, model_fn, data_fn), - nprocs=2, + args=(world_size, str(tmp_path / "store"), opt_name, result_path, cache_dir, None, model_fn, data_fn), + nprocs=world_size, join=True, ) - tol = dict(rtol=1e-2, atol=1e-4) if opt_name in _FSDP_PSGD else {} - _assert_close(ref, torch.load(result_path, weights_only=True), f"{label}/{opt_name}", **tol) + base_tol = dict(rtol=1e-2, atol=1e-4) if opt_name in _FSDP_PSGD else {} + if tol is not None: + base_tol.update({k: max(base_tol.get(k, 0), v) for k, v in tol.items()}) + _assert_close(ref, torch.load(result_path, weights_only=True), f"{label}/{opt_name}", **base_tol) @pytest.mark.parametrize("opt_name", REPRESENTATIVE_OPTS) @@ -339,3 +352,16 @@ def test_fsdp_misaligned(opt_name, tmp_path): @pytest.mark.parametrize("opt_name", _INTEGRATION_OPTS) def test_fsdp_integration(opt_name, tmp_path): _run_fsdp_test(opt_name, tmp_path, _make_integration_model, _make_integration_data, "FSDP-integ") + + +@pytest.mark.parametrize("opt_name", _OWNER_EDGE_OPTS) +def test_fsdp_owner_edge(opt_name, tmp_path): + _run_fsdp_test( + opt_name, + tmp_path, + _make_owner_edge_model, + _make_owner_edge_data, + "FSDP-owner3", + world_size=3, + tol=dict(atol=5e-8), + ) diff --git a/test/test_ecc.py b/test/test_ecc.py index 97c0eab..45c4758 100644 --- a/test/test_ecc.py +++ b/test/test_ecc.py @@ -14,13 +14,13 @@ ULP_MODES = ["bf16+8", "bf16+16", "fp16+8", 
"fp16+16"] _OPTIMIZERS = [ - (heavyball.ForeachAdamW, 5e-2, {}), - (heavyball.ForeachADOPT, 5e-2, {}), - (heavyball.ForeachNAdam, 1e-2, {}), - (heavyball.ForeachLaProp, 5e-2, {}), - (heavyball.ForeachAdEMAMix, 5e-2, {"betas": (0.9, 0.999, 0.9999)}), - (heavyball.ForeachRMSprop, 1e-2, {}), - (heavyball.PaLMForeachSFAdamW, 1e-2, {}), + (heavyball.AdamW, 5e-2, {}), + (heavyball.ADOPT, 5e-2, {}), + (heavyball.NAdam, 1e-2, {}), + (heavyball.LaProp, 5e-2, {}), + (heavyball.AdEMAMix, 5e-2, {"betas": (0.9, 0.999, 0.9999)}), + (heavyball.RMSprop, 1e-2, {}), + (heavyball.SFAdamW, 1e-2, {}), ] @@ -125,12 +125,12 @@ def test_ecc_convergence(opt_cls, lr, extra_kw, mode): def test_param_ecc_convergence(combined): set_torch() data, target = _problem() - m0, o0 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2) + m0, o0 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2) losses_base = _train(m0, o0, data, target, 200) kw = {"param_ecc": "bf16+8"} if combined: kw["ecc"] = "bf16+8" - m1, o1 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, **kw) + m1, o1 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, **kw) losses_ecc = _train(m1, o1, data, target, 200) p = list(m1.parameters())[0] assert p.dtype == torch.bfloat16 @@ -148,7 +148,7 @@ def test_state_layout_and_invariants(mode): cfg = ECCConfig(mode) torch.manual_seed(42) model = nn.Linear(32, 16, bias=False, device="cuda") - opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, ecc=mode) + opt = heavyball.AdamW(model.parameters(), lr=1e-2, ecc=mode) x = torch.randn(4, 32, device="cuda") for _ in range(3): model(x).sum().backward() @@ -174,7 +174,7 @@ def test_ademamix_three_vars(): set_torch() torch.manual_seed(42) model = nn.Linear(64, 32, bias=False, device="cuda") - opt = heavyball.ForeachAdEMAMix(model.parameters(), lr=5e-2, betas=(0.9, 0.999, 0.9999), ecc="bf16+8") + opt = heavyball.AdEMAMix(model.parameters(), lr=5e-2, betas=(0.9, 0.999, 0.9999), ecc="bf16+8") x = torch.randn(4, 64, device="cuda") for _ in range(10): 
model(x).sum().backward() @@ -193,7 +193,7 @@ def test_ademamix_three_vars(): def test_combined_ecc_dtypes(): set_torch() - m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 32, 16, 1e-2, ecc="bf16+16", param_ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 32, 16, 1e-2, ecc="bf16+16", param_ecc="bf16+8") data, target = _problem(in_dim=32, out_dim=16, n=8) _train(m, o, data, target, 100) p = list(m.parameters())[0] @@ -211,7 +211,7 @@ def test_shapes_and_bias(): set_torch() torch.manual_seed(42) model = nn.Sequential(nn.Linear(32, 16, bias=True), nn.Linear(16, 4, bias=False)).cuda() - opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, ecc="bf16+8") + opt = heavyball.AdamW(model.parameters(), lr=1e-2, ecc="bf16+8") data, target = torch.randn(16, 32, device="cuda"), torch.randn(16, 4, device="cuda") losses = _train(model, opt, data, target, 50) for p in model.parameters(): @@ -225,9 +225,9 @@ def test_shapes_and_bias(): def test_foreach_false(): set_torch() data, target = _problem(in_dim=32, out_dim=4, n=16) - m_fe, o_fe = _model_opt(heavyball.ForeachAdamW, 32, 4, 1e-2, ecc="bf16+8", foreach=True) + m_fe, o_fe = _model_opt(heavyball.AdamW, 32, 4, 1e-2, ecc="bf16+8", multi_tensor=True) losses_fe = _train(m_fe, o_fe, data, target, 50) - m_nf, o_nf = _model_opt(heavyball.ForeachAdamW, 32, 4, 1e-2, ecc="bf16+8", foreach=False) + m_nf, o_nf = _model_opt(heavyball.AdamW, 32, 4, 1e-2, ecc="bf16+8", multi_tensor=False) losses_nf = _train(m_nf, o_nf, data, target, 50) assert losses_nf[-1] < losses_nf[0] * 0.5 assert 0.3 < losses_nf[-1] / max(losses_fe[-1], 1e-12) < 3.0 @@ -239,7 +239,7 @@ def test_param_groups(): set_torch() torch.manual_seed(42) m1, m2 = nn.Linear(16, 8, bias=False, device="cuda"), nn.Linear(8, 4, bias=False, device="cuda") - opt = heavyball.ForeachAdamW( + opt = heavyball.AdamW( [ {"params": m1.parameters(), "ecc": "bf16+8"}, {"params": m2.parameters()}, @@ -261,7 +261,7 @@ def test_zero_gradients(): set_torch() torch.manual_seed(42) p = 
nn.Parameter(torch.randn(16, 8, device="cuda")) - opt = heavyball.ForeachAdamW([p], lr=1e-2, ecc="bf16+8") + opt = heavyball.AdamW([p], lr=1e-2, ecc="bf16+8") for _ in range(10): p.grad = torch.zeros_like(p) opt.step() @@ -276,10 +276,10 @@ def test_state_save_restore(): set_torch() data, target = _problem() - m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, ecc="bf16+8") _train(m, o, data, target, 10) sd_opt, sd_model = deepcopy(o.state_dict()), deepcopy(m.state_dict()) - m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, ecc="bf16+8") + m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, ecc="bf16+8") m2.load_state_dict(sd_model) o2.load_state_dict(sd_opt) losses_after = _train(m2, o2, data, target, 10) @@ -326,9 +326,9 @@ def _measure_peak(cls, n, lr, ecc=None, param_ecc=None, steps=3): @pytest.mark.parametrize( "cls,lr", [ - (heavyball.ForeachAdamW, 1e-3), - (heavyball.PaLMForeachSFAdamW, 1e-2), - (heavyball.ForeachRMSprop, 1e-2), + (heavyball.AdamW, 1e-3), + (heavyball.SFAdamW, 1e-2), + (heavyball.RMSprop, 1e-2), ], ids=["AdamW", "SFAdamW", "RMSprop"], ) @@ -346,7 +346,7 @@ def test_ecc_peak_memory(cls, lr, mode): @pytest.mark.parametrize("combined", [False, True], ids=["param_only", "state+param"]) def test_param_ecc_peak_memory(combined): n = 2**24 - cls, lr = heavyball.PaLMForeachSFAdamW, 1e-2 + cls, lr = heavyball.SFAdamW, 1e-2 pre_base, peak_base = _measure_peak(cls, n, lr) ecc = "bf16+8" if combined else None pre_ecc, peak_ecc = _measure_peak(cls, n, lr, ecc=ecc, param_ecc="bf16+8") @@ -357,19 +357,19 @@ def test_param_ecc_peak_memory(combined): def test_ecc_live_path_nonzero_correction(): set_torch() data, target = _problem() - m, o = _model_opt(heavyball.ForeachAdamW, 16, 8, 5e-2, ecc="bf16+8") + m, o = _model_opt(heavyball.AdamW, 16, 8, 5e-2, ecc="bf16+8") _train(m, o, data, target, 20) p = list(m.parameters())[0] st, ecc_keys = _ecc_keys(o, p) for ek in 
ecc_keys: - assert st[ek].any(), f"ECC correction '{ek}' is all zeros — stochastic_round_ likely mutating source" + assert st[ek].any(), f"ECC correction '{ek}' is all zeros - stochastic_round_ likely mutating source" del m, o clean() def test_lerp_returns_valid_fp32(): set_torch() - m, o = _model_opt(heavyball.ForeachAdamW, 16, 8, 5e-2, ecc="bf16+8") + m, o = _model_opt(heavyball.AdamW, 16, 8, 5e-2, ecc="bf16+8") data, target = _problem() _train(m, o, data, target, 10) p = list(m.parameters())[0] @@ -398,7 +398,7 @@ def test_param_ecc_merge_dims(): nn.Flatten(), nn.Linear(64, 8, bias=False), ).cuda() - opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8") + opt = heavyball.SOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8") conv_p = list(model.parameters())[0] assert conv_p.dtype == torch.bfloat16 st = _flat_state(opt, conv_p) @@ -417,7 +417,7 @@ def test_param_ecc_merge_dims(): def test_param_ecc_dtype_at_construction(): set_torch() - m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") p = list(m.parameters())[0] assert p.dtype == torch.bfloat16, "param should be bf16 immediately after construction" st = _flat_state(o, p) @@ -431,11 +431,11 @@ def test_param_ecc_save_restore(): set_torch() data, target = _problem() - m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") _train(m, o, data, target, 10) sd_opt, sd_model = deepcopy(o.state_dict()), deepcopy(m.state_dict()) - m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") m2.load_state_dict(sd_model) o2.load_state_dict(sd_opt) p2 = list(m2.parameters())[0] @@ -452,7 +452,7 @@ def test_param_ecc_partial_state_restore(): set_torch() data, target = _problem() - m, o = 
_model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") _train(m, o, data, target, 10) sd_opt = deepcopy(o.state_dict()) sd_model = deepcopy(m.state_dict()) @@ -460,7 +460,7 @@ def test_param_ecc_partial_state_restore(): for idx_state in param_state.values(): if isinstance(idx_state, dict): idx_state.pop("param::ecc", None) - m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") m2.load_state_dict(sd_model) o2.load_state_dict(sd_opt) st2 = _flat_state(o2, list(m2.parameters())[0]) @@ -478,12 +478,12 @@ def test_param_ecc_empty_state_restore(): set_torch() data, target = _problem() - m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") _train(m, o, data, target, 10) sd_opt = deepcopy(o.state_dict()) sd_model = deepcopy(m.state_dict()) sd_opt["state"] = {} - m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") m2.load_state_dict(sd_model) o2.load_state_dict(sd_opt) p2 = list(m2.parameters())[0] @@ -508,7 +508,7 @@ def test_param_ecc_merged_view_partial_restore(): nn.Flatten(), nn.Linear(64, 8, bias=False), ).cuda() - opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8") + opt = heavyball.SOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8") data = torch.randn(4, 3, 8, 8, device="cuda") target = torch.randn(4, 8, device="cuda") conv_p = list(model.parameters())[0] @@ -529,7 +529,7 @@ def test_param_ecc_merged_view_partial_restore(): nn.Flatten(), nn.Linear(64, 8, bias=False), ).cuda() - opt2 = heavyball.ForeachSOAP(model2.parameters(), lr=1e-3, param_ecc="bf16+8") + opt2 = heavyball.SOAP(model2.parameters(), lr=1e-3, param_ecc="bf16+8") 
model2.load_state_dict(sd_model) opt2.load_state_dict(sd_opt) conv_p2 = list(model2.parameters())[0] @@ -549,7 +549,7 @@ def test_param_ecc_merged_view_partial_restore(): def test_optimizer_kwargs_not_in_param_groups(): set_torch() p = torch.nn.Parameter(torch.randn(4, 4, device="cuda")) - o = heavyball.ForeachAdamW([p], lr=1e-3, compile_step=True, promote=True) + o = heavyball.AdamW([p], lr=1e-3, compile_step=True, promote=True) assert o.compile_step is True assert o.promote is True assert "compile_step" not in o.param_groups[0] @@ -564,7 +564,7 @@ def test_param_ecc_load_order_model_before_optimizer(): set_torch() data, target = _problem() - m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") _train(m, o, data, target, 10) sd_opt = deepcopy(o.state_dict()) sd_model = deepcopy(m.state_dict()) @@ -573,7 +573,7 @@ def test_param_ecc_load_order_model_before_optimizer(): if isinstance(idx_state, dict): idx_state.pop("param::ecc", None) # model-first: load model, then optimizer - m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") m2.load_state_dict(sd_model) o2.load_state_dict(sd_opt) p2 = list(m2.parameters())[0] @@ -594,7 +594,7 @@ def test_param_ecc_load_order_optimizer_before_model(): set_torch() data, target = _problem() - m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") _train(m, o, data, target, 10) sd_opt = deepcopy(o.state_dict()) sd_model = deepcopy(m.state_dict()) @@ -603,7 +603,7 @@ def test_param_ecc_load_order_optimizer_before_model(): if isinstance(idx_state, dict): idx_state.pop("param::ecc", None) # optimizer-first: load optimizer, then model - m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") + m2, o2 = 
_model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8") o2.load_state_dict(sd_opt) m2.load_state_dict(sd_model) p2 = list(m2.parameters())[0] diff --git a/test/test_foreach.py b/test/test_foreach.py index de3fa2a..06dd9f4 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -57,12 +57,12 @@ def test_foreach( losses = [[], []] for i in range(total_runs): - for foreach in [True, False]: - lss, pk = losses[int(foreach)], peaks[int(foreach)] + for multi_tensor in [True, False]: + lss, pk = losses[int(multi_tensor)], peaks[int(multi_tensor)] torch.manual_seed(0x2131290) model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda() - o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach) + o = get_optim(opt, model.parameters(), lr=1e-3, multi_tensor=multi_tensor) clean() @@ -92,17 +92,17 @@ def test_foreach( cutoff = warmup_runs * iterations losses = [loss_list[cutoff:] for loss_list in losses] - for peak_no_foreach, peak_foreach in zip(*peaks): - assert peak_no_foreach < peak_foreach + for peak_single, peak_multi in zip(*peaks): + assert peak_single < peak_multi - # no-foreach LRA is a different optimizer (per-parameter LRA vs global LRA), + # single-tensor LRA is a different optimizer (per-parameter LRA vs global LRA), # so we only check that both converge, not that they match. 
if "LRA" in opt.__name__: return - for loss_no_foreach, loss_foreach in zip(*losses): - if torch.isnan(loss_no_foreach) and torch.isnan(loss_foreach): + for loss_single, loss_multi in zip(*losses): + if torch.isnan(loss_single) and torch.isnan(loss_multi): continue # increase error tolerance for PSGD, as we have different RNGs -> expected differences - assert torch.allclose(loss_no_foreach, loss_foreach, rtol=0.01 if "PSGD" in opt.__name__ else 1e-5) + assert torch.allclose(loss_single, loss_multi, rtol=0.01 if "PSGD" in opt.__name__ else 1e-5) diff --git a/test/test_helpers_cpu.py b/test/test_helpers_cpu.py index bc5952c..5a7ed4e 100644 --- a/test/test_helpers_cpu.py +++ b/test/test_helpers_cpu.py @@ -3,6 +3,7 @@ import numpy as np import optuna import pandas as pd +import pytest import torch from optuna.distributions import FloatDistribution, IntDistribution from optuna.samplers import RandomSampler @@ -81,6 +82,22 @@ def _dummy_candidates(params, values, *_args): assert 0.0 <= suggestion["width"] <= 1.0 +def test_helper_samplers_reject_removed_compat_kwargs(): + search_space = {"width": FloatDistribution(0.0, 1.0)} + + with pytest.raises(TypeError): + helpers.BoTorchSampler(search_space, consider_running_trials=True) + + with pytest.raises(TypeError): + helpers.HEBOSampler(search_space, constant_liar=True) + + with pytest.raises(TypeError): + helpers.ImplicitNaturalGradientSampler(search_space, warn_independent_sampling=False) + + with pytest.raises(TypeError): + helpers.AutoSampler(search_space=search_space, constraints_func=lambda *_args: None) + + def test_hebo_sampler_observe_and_sample(monkeypatch): class DummyHEBO: def __init__(self, *_args, **_kwargs): diff --git a/test/test_memory_leak.py b/test/test_memory_leak.py index f7088dc..f6c6f50 100644 --- a/test/test_memory_leak.py +++ b/test/test_memory_leak.py @@ -34,7 +34,7 @@ def forward(self, x): def test_memory( - opt: str = "NewtonHybrid2PSGDKron", + opt: str = "PSGDKron", size: int = 64, depth: int = 2, 
mars: bool = False, diff --git a/test/test_merge.py b/test/test_merge.py index 90bafb2..d1a660f 100644 --- a/test/test_merge.py +++ b/test/test_merge.py @@ -18,7 +18,7 @@ def forward(self, inp): return self.weight.mean() * inp -@pytest.mark.parametrize("opt", ["ForeachPSGDKron"]) +@pytest.mark.parametrize("opt", ["PSGDKron"]) @pytest.mark.parametrize("size", [(16, 16, 16, 16), (4, 4, 4, 4), (512, 1, 128), (32128, 768)]) @pytest.mark.parametrize("merge,split", [(False, False), (True, False), (True, True)]) def test_merge( diff --git a/test/test_migrate_cli.py b/test/test_migrate_cli.py index 1554d09..0130305 100644 --- a/test/test_migrate_cli.py +++ b/test/test_migrate_cli.py @@ -1,5 +1,7 @@ import importlib.util import pathlib +import pickle +import random import sys import pytest @@ -14,165 +16,723 @@ sys.modules[MODULE_NAME] = migrate_script SPEC.loader.exec_module(migrate_script) # type: ignore[arg-type] +# --------------------------------------------------------------------------- +# Rename tables (exhaustive, mirrors _CLASS_RENAMES in the script) +# --------------------------------------------------------------------------- + +_DIRECT_RENAMES = [ + ("ForeachAdamW", "AdamW"), + ("ForeachNAdam", "NAdam"), + ("ForeachAdEMAMix", "AdEMAMix"), + ("ForeachAdamC", "AdamC"), + ("ForeachRMSprop", "RMSprop"), + ("ForeachSFAdamW", "SFAdamW"), + ("ForeachADOPT", "ADOPT"), + ("ForeachMuon", "Muon"), + ("ForeachLaProp", "LaProp"), + ("ForeachSOAP", "SOAP"), + ("ForeachSOAPNAdam", "SOAPNAdam"), + ("ForeachSOAPAdEMAMix", "SOAPAdEMAMix"), + ("ForeachSignLaProp", "SignLaProp"), + ("ForeachSOLP", "SOLP"), + ("ForeachPSGDKron", "PSGDKron"), + ("ForeachPSGDLRA", "PSGDLRA"), +] + +_DELETED_RENAMES = [ + ("PaLMForeachSFAdamW", "SFAdamW"), + ("PaLMForeachSOAP", "SOAP"), + ("PrecondScheduleForeachSOAP", "SOAP"), + ("PrecondSchedulePaLMForeachSOAP", "SOAP"), + ("ForeachPurePSGD", "PSGDKron"), + ("ForeachCachedPSGDKron", "PSGDKron"), + ("ForeachCachedDelayedPSGDKron", "PSGDKron"), + 
("ForeachDelayedPSGD", "PSGDKron"), + ("ForeachCachedNewtonPSGD", "PSGDKron"), + ("NewtonHybrid2PSGDKron", "PSGDKron"), + ("ForeachDelayedPSGDLRA", "PSGDLRA"), + ("ForeachNewtonPSGDLRA", "PSGDLRA"), + ("NewtonHybrid2PSGDLRA", "PSGDLRA"), +] + +_ALL_RENAMES = _DIRECT_RENAMES + _DELETED_RENAMES + +# PSGDLRA uses a different constructor signature (beta not betas, rank param, etc.) +# so e2e tests using AdamW-style param_groups must exclude it +_LRA_TARGETS = {"PSGDLRA"} +_DIRECT_RENAMES_NO_LRA = [(o, n) for o, n in _DIRECT_RENAMES if n not in _LRA_TARGETS] +_DELETED_RENAMES_NO_LRA = [(o, n) for o, n in _DELETED_RENAMES if n not in _LRA_TARGETS] +_DIRECT_RENAMES_LRA = [(o, n) for o, n in _DIRECT_RENAMES if n in _LRA_TARGETS] +_DELETED_RENAMES_LRA = [(o, n) for o, n in _DELETED_RENAMES if n in _LRA_TARGETS] + +_PASSTHROUGH = ["AdamW", "NAdam", "PSGDKron", "SGD", "SOAP", "Muon", "ADOPT", "LaProp", "PSGDLRA", "SFAdamW"] + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _flat_state(*shapes): + return { + i: { + "update_by_adam_exp_avg": torch.ones(s), + "update_by_adam_exp_avg_sq": torch.full(s, 2.0), + "is_initialized": [0], + } + for i, s in enumerate(shapes) + } + + +def _nested_state(*shapes): + return {i: {0: {"key_0": torch.zeros(s), "is_initialized": [0]}} for i, s in enumerate(shapes)} + + +def _v1_group(pids): + return { + "params": pids, + "lr": 0.0025, + "betas": [0.9, 0.99], + "eps": 1e-8, + "weight_decay": 0.0, + "warmup_steps": 0, + "foreach": True, + "stochastic_schedule": False, + "storage_dtype": "float32", + "mars": False, + "caution": False, + "mars_gamma": 0.0025, + "gradient_clipping": "use_default", + "update_clipping": "use_default", + "palm": "use_default", + "beta2_scale": 0.8, + } + + +def _v2_group(pids): + return {**_v1_group(pids), "__class__": "heavyball.ForeachAdamW"} + + +def _v3_group(pids): + g = _v1_group(pids) 
+ g.pop("foreach") + g.pop("stochastic_schedule") + g["multi_tensor"] = True + return g + + +def _v2_meta(): + return { + "inner_group": {"stochastic_schedule": None}, + "stochastic_schedule": None, + "precond_rng": pickle.dumps(random.Random(0x12312)), + "use_ema": False, + } + + +def _v3_meta(): + return {"inner_group": {}, "use_ema": False} + + +def _load_heavyball_fresh(): + package_root = pathlib.Path(__file__).resolve().parents[1] + heavyball_pkg = package_root / "heavyball" + saved = {n: sys.modules[n] for n in list(sys.modules) if n == "heavyball" or n.startswith("heavyball.")} + for n in list(sys.modules): + if n == "heavyball" or n.startswith("heavyball."): + sys.modules.pop(n) + spec = importlib.util.spec_from_file_location( + "heavyball", heavyball_pkg / "__init__.py", submodule_search_locations=[str(heavyball_pkg)] + ) + mod = importlib.util.module_from_spec(spec) + sys.modules["heavyball"] = mod + spec.loader.exec_module(mod) + return saved + + +def _restore_heavyball(saved): + for n in list(sys.modules): + if n == "heavyball" or n.startswith("heavyball."): + sys.modules.pop(n) + sys.modules.update(saved) + @pytest.fixture() def runner(): return CliRunner() -def test_cli_dry_run_updates_state(monkeypatch, runner, tmp_path): - checkpoint_path = tmp_path / "ckpt.pt" - checkpoint_path.touch() +# ==================================================================== +# _resolve_class_name +# ==================================================================== - state_container = {"state": {"initial": True}, "param_groups": ["group"]} - checkpoint = {"optimizer": state_container} - def fake_load(path, map_location=None): - return checkpoint +@pytest.mark.parametrize("old,new", _ALL_RENAMES) +def test_resolve_class_name_renames(old, new): + assert migrate_script._resolve_class_name(old) == new - def fake_migrate(state, _): - return {"state": {"migrated": True}, "param_groups": []} - def fail_save(*args, **kwargs): - pytest.fail("torch.save should not run 
during dry-run") +@pytest.mark.parametrize("name", _PASSTHROUGH) +def test_resolve_class_name_passthrough(name): + assert migrate_script._resolve_class_name(name) == name - monkeypatch.setattr(migrate_script.torch, "load", fake_load) - monkeypatch.setattr(migrate_script, "migrate_state_dict", fake_migrate) - monkeypatch.setattr(migrate_script.torch, "save", fail_save) - result = runner.invoke( - migrate_script.app, - [str(checkpoint_path), "heavyball.Mock", "--state-key", "optimizer", "--dry-run"], - ) +# ==================================================================== +# _load_optimizer_class +# ==================================================================== + + +@pytest.mark.parametrize("old,new", _ALL_RENAMES) +def test_load_optimizer_class_old_names(old, new): + import heavyball + + assert migrate_script._load_optimizer_class(f"heavyball.{old}") is getattr(heavyball, new) + + +@pytest.mark.parametrize("name", _PASSTHROUGH) +def test_load_optimizer_class_current_names(name): + import heavyball + + assert migrate_script._load_optimizer_class(name) is getattr(heavyball, name) + + +def test_load_optimizer_class_bare_name(): + import heavyball + + assert migrate_script._load_optimizer_class("AdamW") is heavyball.AdamW + + +def test_load_optimizer_class_invalid_raises(): + with pytest.raises(ValueError, match="not found"): + migrate_script._load_optimizer_class("heavyball.TotallyFakeOptimizer") + + +# ==================================================================== +# _detect_version +# ==================================================================== + +_DETECT_CASES = [ + ("flat_foreach", {"state": _flat_state((2,)), "param_groups": [{"foreach": True}]}, 1), + ("flat_multi_tensor", {"state": _flat_state((2,)), "param_groups": [{"multi_tensor": True}]}, 1), + ("flat_neither", {"state": _flat_state((2,)), "param_groups": [{}]}, 1), + ("flat_multi_group", {"state": _flat_state((2,), (3,)), "param_groups": [{"foreach": True}, {}]}, 1), + 
("nested_foreach", {"state": _nested_state((2,)), "param_groups": [{"foreach": True}]}, 2), + ( + "nested_foreach_multi_group", + {"state": _nested_state((2,), (3,)), "param_groups": [{"foreach": True}, {"foreach": True}]}, + 2, + ), + ("nested_multi_tensor", {"state": _nested_state((2,)), "param_groups": [{"multi_tensor": True}]}, 3), + ( + "nested_multi_tensor_multi_group", + {"state": _nested_state((2,), (3,)), "param_groups": [{"multi_tensor": True}, {}]}, + 3, + ), + ("empty_state_multi_tensor", {"state": {}, "param_groups": [{"multi_tensor": True}]}, 3), + ("empty_state_foreach", {"state": {}, "param_groups": [{"foreach": True}]}, 1), + ("empty_state_bare", {"state": {}, "param_groups": [{}]}, 1), + ("empty_state_empty_groups", {"state": {}, "param_groups": []}, 1), +] + + +@pytest.mark.parametrize("name,sd,expected", _DETECT_CASES, ids=[c[0] for c in _DETECT_CASES]) +def test_detect_version(name, sd, expected): + assert migrate_script._detect_version(sd) == expected + + +# ==================================================================== +# _normalise_group_options +# ==================================================================== + +_NORM_CASES = [ + ("foreach_renamed", {"foreach": True, "lr": 0.01}, {"multi_tensor": True, "lr": 0.01}), + ("foreach_false", {"foreach": False, "lr": 0.01}, {"multi_tensor": False, "lr": 0.01}), + ("multi_tensor_passthrough", {"multi_tensor": False, "lr": 0.01}, {"multi_tensor": False, "lr": 0.01}), + ("stochastic_stripped", {"foreach": True, "stochastic_schedule": False}, {"multi_tensor": True}), + ("betas_tuple", {"betas": [0.9, 0.999]}, {"betas": (0.9, 0.999)}), + ("weight_decay_steps_tuple", {"weight_decay_steps": [100, 200]}, {"weight_decay_steps": (100, 200)}), + ("params_stripped", {"params": [0, 1], "lr": 0.01}, {"lr": 0.01}), + ("empty_group", {}, {}), + ("params_only", {"params": [0]}, {}), + ("all_removed", {"params": [0], "stochastic_schedule": None}, {}), + ( + "kitchen_sink", + {"params": [0, 1], 
"foreach": True, "stochastic_schedule": True, "lr": 0.01, "betas": [0.9, 0.99]}, + {"multi_tensor": True, "lr": 0.01, "betas": (0.9, 0.99)}, + ), +] + + +@pytest.mark.parametrize("name,group,expected", _NORM_CASES, ids=[c[0] for c in _NORM_CASES]) +def test_normalise_group_options(name, group, expected): + assert migrate_script._normalise_group_options(group) == expected + + +# ==================================================================== +# _ensure_set +# ==================================================================== + + +@pytest.mark.parametrize( + "inp,expected", + [ + ({1, 2}, {1, 2}), + ([3, 4], {3, 4}), + ((5,), {5}), + (None, set()), + (7, {7}), + ([], set()), + ], + ids=["set", "list", "tuple", "none", "scalar", "empty_list"], +) +def test_ensure_set(inp, expected): + assert migrate_script._ensure_set(inp) == expected + + +# ==================================================================== +# _guess_tensor_meta +# ==================================================================== + + +@pytest.mark.parametrize( + "entry,shape,dtype", + [ + ({"a": torch.zeros(3, 4, dtype=torch.float16)}, (3, 4), torch.float16), + ({"a": [torch.zeros(2, dtype=torch.bfloat16)]}, (2,), torch.bfloat16), + ({"a": "not_a_tensor", "b": torch.ones(5)}, (5,), torch.float32), + ({}, (1,), torch.float32), + ({"a": 42}, (1,), torch.float32), + ], + ids=["tensor", "tensor_list", "mixed_finds_tensor", "empty", "no_tensors"], +) +def test_guess_tensor_meta(entry, shape, dtype): + s, d = migrate_script._guess_tensor_meta(entry) + assert s == shape + assert d == dtype + + +# ==================================================================== +# _resolve_state_container +# ==================================================================== + + +def test_resolve_state_container_single_key(): + root = {"optimizer": {"state": {}, "param_groups": []}} + assert migrate_script._resolve_state_container(root, ["optimizer"]) is root["optimizer"] + + +def 
test_resolve_state_container_nested_key(): + inner = {"state": {}, "param_groups": []} + root = {"model": {"training": {"opt": inner}}} + assert migrate_script._resolve_state_container(root, ["model", "training", "opt"]) is inner + + +def test_resolve_state_container_missing_key(): + with pytest.raises(KeyError, match="not found"): + migrate_script._resolve_state_container({"a": {}}, ["a", "b"]) + + +def test_resolve_state_container_not_optimizer(): + with pytest.raises(ValueError, match="not an optimizer state dict"): + migrate_script._resolve_state_container({"opt": {"just_data": 1}}, ["opt"]) + + +# ==================================================================== +# _migrate_v2_to_v3 +# ==================================================================== + + +@pytest.mark.parametrize( + "meta_keys", + [ + {"stochastic_schedule": None, "precond_rng": pickle.dumps(random.Random(0))}, + {"stochastic_schedule": None}, + {"precond_rng": b"rng"}, + {}, + ], + ids=["both", "stochastic_only", "precond_rng_only", "clean"], +) +def test_migrate_v2_to_v3_strips_stale_meta(meta_keys): + sd = { + "state": _nested_state((3,)), + "param_groups": [{"params": [0], "foreach": True, "stochastic_schedule": False}], + "heavyball": {"inner_group": {"stochastic_schedule": None}, "use_ema": False, **meta_keys}, + } + migrate_script._migrate_v2_to_v3(sd) + group = sd["param_groups"][0] + assert group["multi_tensor"] is True + assert "foreach" not in group + assert "stochastic_schedule" not in group + hb = sd["heavyball"] + for key in ("stochastic_schedule", "precond_rng"): + assert key not in hb + assert key not in hb["inner_group"] + assert hb["use_ema"] is False + + +def test_migrate_v2_to_v3_multi_group(): + sd = { + "state": _nested_state((2,), (3,)), + "param_groups": [ + {"params": [0], "foreach": True, "stochastic_schedule": False}, + {"params": [1], "foreach": False, "stochastic_schedule": True, "lr": 0.1}, + ], + "heavyball": _v2_meta(), + } + 
migrate_script._migrate_v2_to_v3(sd) + assert sd["param_groups"][0]["multi_tensor"] is True + assert sd["param_groups"][1]["multi_tensor"] is False + assert sd["param_groups"][1]["lr"] == 0.1 + for g in sd["param_groups"]: + assert "foreach" not in g + assert "stochastic_schedule" not in g + + +def test_migrate_v2_to_v3_preserves_non_stale_meta(): + sd = { + "state": _nested_state((2,)), + "param_groups": [{"params": [0], "foreach": True}], + "heavyball": {**_v2_meta(), "use_ema": True, "ema_decay": 0.005, "compile_step": True}, + } + migrate_script._migrate_v2_to_v3(sd) + assert sd["heavyball"]["use_ema"] is True + assert sd["heavyball"]["ema_decay"] == 0.005 + assert sd["heavyball"]["compile_step"] is True + + +def test_migrate_v2_to_v3_no_heavyball_key(): + sd = { + "state": _nested_state((2,)), + "param_groups": [{"params": [0], "foreach": True}], + } + migrate_script._migrate_v2_to_v3(sd) + assert sd["param_groups"][0]["multi_tensor"] is True + + +# ==================================================================== +# _migrate_single_state +# ==================================================================== + + +@pytest.mark.parametrize("n_views", [1, 2, 5]) +def test_migrate_single_state_rewrites_keys(n_views): + mappings = [ + migrate_script.TransformMapping("exp_avg", "exp_avg_0", 0), + migrate_script.TransformMapping("exp_avg_sq", "exp_avg_sq_0", 0), + ] + entry = { + "exp_avg": [torch.ones(2)] * n_views if n_views > 1 else torch.ones(2), + "exp_avg_sq": [torch.full((2,), 2.0)] * n_views if n_views > 1 else torch.full((2,), 2.0), + "is_initialized": [0], + } + migrated = migrate_script._migrate_single_state(entry, mappings) + for view_idx in range(n_views): + bucket = migrated[view_idx] + assert "exp_avg_0" in bucket + assert "exp_avg_sq_0" in bucket + assert "exp_avg" not in bucket + assert "exp_avg_sq" not in bucket + assert 0 in bucket["is_initialized"] + + +def test_migrate_single_state_empty_mappings(): + entry = {"some_key": torch.ones(3), 
"is_initialized": {0, 1}} + migrated = migrate_script._migrate_single_state(entry, []) + assert "some_key" in migrated[0] + assert set(migrated[0]["is_initialized"]) == {0, 1} + + +def test_migrate_single_state_multiple_transforms(): + mappings = [ + migrate_script.TransformMapping("exp_avg", "update_by_adam_exp_avg_0", 0), + migrate_script.TransformMapping("exp_avg_sq", "update_by_adam_exp_avg_sq_0", 0), + migrate_script.TransformMapping("momentum", "heavyball_momentum_momentum_1", 1), + ] + entry = { + "exp_avg": torch.ones(4), + "exp_avg_sq": torch.full((4,), 2.0), + "momentum": torch.full((4,), 0.5), + "is_initialized": [0, 1], + } + migrated = migrate_script._migrate_single_state(entry, mappings) + bucket = migrated[0] + assert "update_by_adam_exp_avg_0" in bucket + assert "update_by_adam_exp_avg_sq_0" in bucket + assert "heavyball_momentum_momentum_1" in bucket + assert set(bucket["is_initialized"]) == {0, 1} + + +def test_migrate_single_state_preserves_values(): + t = torch.arange(6, dtype=torch.float32) + mappings = [migrate_script.TransformMapping("old", "new_0", 0)] + migrated = migrate_script._migrate_single_state({"old": t, "is_initialized": [0]}, mappings) + assert torch.equal(migrated[0]["new_0"], t) + +# ==================================================================== +# Full migrate_state_dict +# ==================================================================== + + +def test_migrate_v2_state_dict(): + sd = { + "state": _nested_state((4,)), + "param_groups": [_v2_group([0])], + "heavyball": _v2_meta(), + } + migrated = migrate_script.migrate_state_dict(sd, "heavyball.ForeachAdamW") + group = migrated["param_groups"][0] + assert group["multi_tensor"] is True + assert "foreach" not in group + assert "stochastic_schedule" not in group + hb = migrated.get("heavyball", {}) + assert "stochastic_schedule" not in hb + assert "precond_rng" not in hb + + +def test_migrate_v3_is_noop(): + sd = { + "state": _nested_state((4,)), + "param_groups": 
[_v3_group([0])], + "heavyball": _v3_meta(), + } + original_key0 = sd["state"][0][0]["key_0"].clone() + migrated = migrate_script.migrate_state_dict(sd, "heavyball.AdamW") + assert migrated is not sd + assert torch.equal(migrated["state"][0][0]["key_0"], original_key0) + assert migrated["param_groups"][0]["multi_tensor"] is True + + +def test_migrate_v2_with_old_class_name(): + sd = { + "state": _nested_state((4,)), + "param_groups": [_v2_group([0])], + "heavyball": _v2_meta(), + } + migrated = migrate_script.migrate_state_dict(sd, "ForeachAdamW") + assert migrated["param_groups"][0]["multi_tensor"] is True + + +def test_migrate_v2_multi_param(): + sd = { + "state": _nested_state((4, 4), (8,), (2, 2, 2)), + "param_groups": [_v2_group([0, 1, 2])], + "heavyball": _v2_meta(), + } + migrated = migrate_script.migrate_state_dict(sd, "AdamW") + for pid in (0, 1, 2): + assert pid in migrated["state"] + assert 0 in migrated["state"][pid] + + +# ==================================================================== +# CLI (typer) | mocked +# ==================================================================== + + +def test_cli_dry_run(monkeypatch, runner, tmp_path): + ckpt = tmp_path / "ckpt.pt" + ckpt.touch() + container = {"state": {}, "param_groups": []} + checkpoint = {"optimizer": container} + + monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: checkpoint) + monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {"ok": True}, "param_groups": []}) + monkeypatch.setattr(migrate_script.torch, "save", lambda *a, **kw: pytest.fail("save during dry-run")) + + result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--dry-run"]) assert result.exit_code == 0 assert "Dry run" in result.stdout - assert state_container == {"state": {"migrated": True}, "param_groups": []} + assert container["state"] == {"ok": True} def test_cli_writes_output(monkeypatch, runner, tmp_path): - checkpoint_path = tmp_path / "source.pt" - 
checkpoint_path.touch() - output_path = tmp_path / "out.pt" - - state_container = {"state": {"initial": True}, "param_groups": ["group"]} - checkpoint = {"optimizer": state_container} - migrated = {"state": {"migrated": True}, "param_groups": []} + ckpt = tmp_path / "source.pt" + ckpt.touch() + out = tmp_path / "out.pt" + migrated = {"state": {"done": True}, "param_groups": []} + checkpoint = {"optimizer": {"state": {}, "param_groups": []}} + saved = {} - def fake_load(path, map_location=None): - return checkpoint + monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: checkpoint) + monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: migrated) + monkeypatch.setattr(migrate_script.torch, "save", lambda obj, path: saved.update(obj=obj, path=pathlib.Path(path))) - def fake_migrate(state, _): - return migrated + result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--output", str(out)]) + assert result.exit_code == 0 + assert saved["path"] == out + assert saved["obj"]["optimizer"] == migrated - monkeypatch.setattr(migrate_script.torch, "load", fake_load) - monkeypatch.setattr(migrate_script, "migrate_state_dict", fake_migrate) +def test_cli_overwrites_input_by_default(monkeypatch, runner, tmp_path): + ckpt = tmp_path / "ckpt.pt" + ckpt.touch() saved = {} - def fake_save(obj, path): - saved["obj"] = obj - saved["path"] = pathlib.Path(path) + monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: {"optimizer": {"state": {}, "param_groups": []}}) + monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {}, "param_groups": []}) + monkeypatch.setattr(migrate_script.torch, "save", lambda obj, path: saved.update(path=pathlib.Path(path))) + + result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock"]) + assert result.exit_code == 0 + assert saved["path"] == ckpt - monkeypatch.setattr(migrate_script.torch, "save", fake_save) - result = runner.invoke( - migrate_script.app, - 
[str(checkpoint_path), "heavyball.Mock", "--output", str(output_path)], - ) +@pytest.mark.parametrize("key", ["optimizer", "model.optimizer", "a.b.c"]) +def test_cli_state_key_parsing(monkeypatch, runner, tmp_path, key): + ckpt = tmp_path / "ckpt.pt" + ckpt.touch() + parts = key.split(".") + inner = {"state": {}, "param_groups": []} + root = inner + for p in reversed(parts): + root = {p: root} + monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: root) + monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {}, "param_groups": []}) + monkeypatch.setattr(migrate_script.torch, "save", lambda *a, **kw: None) + + result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--state-key", key]) assert result.exit_code == 0 - assert saved["path"] == output_path - assert saved["obj"]["optimizer"] == migrated -def test_cli_migrates_legacy_checkpoint(runner, tmp_path): - package_root = pathlib.Path(__file__).resolve().parents[1] - heavyball_pkg = package_root / "heavyball" - saved_heavyball_modules = { - name: sys.modules[name] for name in list(sys.modules) if name == "heavyball" or name.startswith("heavyball.") - } - for name in list(sys.modules): - if name == "heavyball" or name.startswith("heavyball."): - sys.modules.pop(name) +# ==================================================================== +# CLI - real end-to-end (fresh heavyball import) +# ==================================================================== - spec = importlib.util.spec_from_file_location( - "heavyball", - heavyball_pkg / "__init__.py", - submodule_search_locations=[str(heavyball_pkg)], - ) - heavyball_module = importlib.util.module_from_spec(spec) - sys.modules["heavyball"] = heavyball_module - spec.loader.exec_module(heavyball_module) - try: - checkpoint_path = tmp_path / "legacy.pt" - output_path = tmp_path / "migrated.pt" - - legacy_state = { - "state": { - 0: { - "update_by_adam_exp_avg": torch.ones((2, 2), dtype=torch.float32), - 
"update_by_adam_exp_avg_sq": torch.full((2, 2), 2.0, dtype=torch.float32), - "is_initialized": [0], - }, - 1: { - "update_by_adam_exp_avg": torch.ones((2,), dtype=torch.float32), - "update_by_adam_exp_avg_sq": torch.full((2,), 2.0, dtype=torch.float32), - "is_initialized": [0], - }, - }, - "param_groups": [ - { - "params": [0, 1], - "lr": 0.0025, - "betas": [0.9, 0.99], - "eps": 1e-8, - "weight_decay": 0.0, - "warmup_steps": 0, - "foreach": True, - "storage_dtype": "float32", - "mars": False, - "caution": False, - "mars_gamma": 0.0025, - "gradient_clipping": "use_default", - "update_clipping": "use_default", - "palm": "use_default", - "beta2_scale": 0.8, - "__class__": "heavyball.ForeachAdamW", - } - ], - } +def _make_v1_checkpoint(path, shapes, *, group_overrides=None): + g = _v1_group(list(range(len(shapes)))) + g["multi_tensor"] = True + g.pop("foreach") + g.pop("stochastic_schedule") + if group_overrides: + g.update(group_overrides) + torch.save({"optimizer": {"state": _flat_state(*shapes), "param_groups": [g]}}, path) - torch.save({"optimizer": legacy_state}, checkpoint_path) - result = runner.invoke( - migrate_script.app, - [str(checkpoint_path), "heavyball.ForeachAdamW", "--output", str(output_path)], - ) +def _make_v2_checkpoint(path, shapes, *, group_overrides=None): + g = _v2_group(list(range(len(shapes)))) + if group_overrides: + g.update(group_overrides) + torch.save({"optimizer": {"state": _nested_state(*shapes), "param_groups": [g], "heavyball": _v2_meta()}}, path) + +def _make_v3_checkpoint(path, shapes, *, group_overrides=None): + g = _v3_group(list(range(len(shapes)))) + if group_overrides: + g.update(group_overrides) + torch.save({"optimizer": {"state": _nested_state(*shapes), "param_groups": [g], "heavyball": _v3_meta()}}, path) + + +def _run_cli_e2e(runner, ckpt, out, class_name): + saved = _load_heavyball_fresh() + try: + result = runner.invoke(migrate_script.app, [str(ckpt), class_name, "--output", str(out)]) assert result.exit_code == 0, 
result.stderr or result.stdout - assert output_path.exists() - - migrated = torch.load(output_path) - migrated_state = migrated["optimizer"] - - for pid, shape in [(0, (2, 2)), (1, (2,))]: - assert pid in migrated_state["state"], f"missing state for parameter {pid}" - migrated_bucket = migrated_state["state"][pid] - assert 0 in migrated_bucket, f"missing transformed view for parameter {pid}" - view_state = migrated_bucket[0] - assert view_state["update_by_adam_exp_avg_0"].shape == shape - assert view_state["update_by_adam_exp_avg_sq_0"].shape == shape - assert torch.allclose(view_state["update_by_adam_exp_avg_0"], torch.ones(shape)) - assert torch.allclose(view_state["update_by_adam_exp_avg_sq_0"], torch.full(shape, 2.0)) - assert view_state["is_initialized"] == [0] - - assert "heavyball" in migrated_state - assert migrated_state["heavyball"]["inner_group"]["stochastic_schedule"] is None - assert "Migrated checkpoint written to" in result.stdout + return torch.load(out)["optimizer"] finally: - for name in list(sys.modules): - if name == "heavyball" or name.startswith("heavyball."): - sys.modules.pop(name) - sys.modules.update(saved_heavyball_modules) + _restore_heavyball(saved) + + +# --- v1 end-to-end --- + + +def test_e2e_v1_adamw(runner, tmp_path): + ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt" + _make_v1_checkpoint(ckpt, [(2, 2), (2,)]) + migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW") + for pid, shape in [(0, (2, 2)), (1, (2,))]: + view = migrated["state"][pid][0] + assert view["update_by_adam_exp_avg_0"].shape == shape + assert view["update_by_adam_exp_avg_sq_0"].shape == shape + assert torch.allclose(view["update_by_adam_exp_avg_0"], torch.ones(shape)) + assert view["is_initialized"] == [0] + assert "heavyball" in migrated + + +@pytest.mark.parametrize("old_name,new_name", _DIRECT_RENAMES_NO_LRA) +def test_e2e_v1_all_direct_renames(runner, tmp_path, old_name, new_name): + ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt" + 
_make_v1_checkpoint(ckpt, [(4,)]) + migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}") + assert 0 in migrated["state"] + assert "heavyball" in migrated + + +@pytest.mark.parametrize("old_name,new_name", _DELETED_RENAMES_NO_LRA) +def test_e2e_v1_all_deleted_renames(runner, tmp_path, old_name, new_name): + ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt" + _make_v1_checkpoint(ckpt, [(4,)]) + migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}") + assert 0 in migrated["state"] + assert "heavyball" in migrated + + +# PSGDLRA v1 e2e: PSGDLRA computes rank from param numel during __init__, +# which the migration script's _instantiate_optimizer can't replicate from +# a state dict alone (rank isn't stored in param_groups). The rename +# resolution and migration logic for LRA is covered by the unit tests above. + + +@pytest.mark.parametrize("n_params", [1, 3, 5]) +def test_e2e_v1_varying_param_count(runner, tmp_path, n_params): + ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt" + shapes = [(i + 2,) for i in range(n_params)] + _make_v1_checkpoint(ckpt, shapes) + migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW") + for pid in range(n_params): + assert pid in migrated["state"] + + +@pytest.mark.parametrize("shape", [(4,), (2, 3), (2, 3, 4)]) +def test_e2e_v1_varying_shapes(runner, tmp_path, shape): + ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt" + _make_v1_checkpoint(ckpt, [shape]) + migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW") + assert migrated["state"][0][0]["update_by_adam_exp_avg_0"].shape == shape + + +# --- v2 end-to-end --- + + +def test_e2e_v2_basic(runner, tmp_path): + ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt" + _make_v2_checkpoint(ckpt, [(4,)]) + migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW") + group = migrated["param_groups"][0] + assert group["multi_tensor"] is True + assert "foreach" not in group + assert "stochastic_schedule" not in group + hb = migrated["heavyball"] 
+ assert "precond_rng" not in hb + assert "stochastic_schedule" not in hb + + +@pytest.mark.parametrize("old_name,new_name", _DIRECT_RENAMES_NO_LRA) +def test_e2e_v2_all_direct_renames(runner, tmp_path, old_name, new_name): + ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt" + _make_v2_checkpoint(ckpt, [(4,)]) + migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}") + assert migrated["param_groups"][0]["multi_tensor"] is True + assert "foreach" not in migrated["param_groups"][0] + + +@pytest.mark.parametrize("n_params", [1, 3, 5]) +def test_e2e_v2_varying_param_count(runner, tmp_path, n_params): + ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt" + shapes = [(i + 2,) for i in range(n_params)] + _make_v2_checkpoint(ckpt, shapes) + migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW") + for pid in range(n_params): + assert pid in migrated["state"] + + +# --- v3 end-to-end (noop) --- + + +def test_e2e_v3_noop(runner, tmp_path): + ckpt, out = tmp_path / "v3.pt", tmp_path / "out.pt" + _make_v3_checkpoint(ckpt, [(4,)]) + migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW") + assert migrated["param_groups"][0]["multi_tensor"] is True + assert torch.equal(migrated["state"][0][0]["key_0"], torch.zeros(4)) diff --git a/test/test_optimizer_cpu_smoke.py b/test/test_optimizer_cpu_smoke.py index 9a9fa68..7875d5f 100644 --- a/test/test_optimizer_cpu_smoke.py +++ b/test/test_optimizer_cpu_smoke.py @@ -1,9 +1,11 @@ +import functools import inspect import pytest import torch import heavyball +from heavyball import chainable as C from heavyball.utils import StatefulOptimizer DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -31,7 +33,7 @@ def _optimizer_params(): ) ) continue - if name == "ForeachSOAPNAdam": + if name == "SOAPNAdam": params.append( pytest.param( name, @@ -80,3 +82,44 @@ def closure(): clone = opt_cls(model.parameters()) clone.load_state_dict(state_dict) assert clone.state_dict()["state"].keys() == 
state_dict["state"].keys() + + +def test_optimizer_keeps_constructor_compatibility_features(): + param = torch.nn.Parameter(torch.randn(4, 4, device=DEVICE)) + + with pytest.warns(FutureWarning, match="renamed to 'multi_tensor'"): + optimizer = heavyball.AdamW([param], foreach=True) + assert optimizer.param_groups[0]["multi_tensor"] is True + + with pytest.raises(TypeError, match="Removed in HeavyBall"): + heavyball.SOAP([param], normalize_grads=True) + + with pytest.warns(UserWarning, match="Working with uncaptured keyword arguments"): + heavyball.AdamW([param], totally_fake=True) + + +def test_optimizer_accepts_explicit_orig_shapes(): + param = torch.nn.Parameter(torch.randn(4, 4, device=DEVICE)) + shapes = heavyball.capture_param_shapes([param]) + optimizer = heavyball.AdamW([param], orig_shapes=shapes) + assert "orig_shapes" not in optimizer.param_groups[0] + + +def test_subclass_defaults_still_apply(): + class ScheduledSOAP(heavyball.SOAP): + use_precond_schedule = True + + class DelayedPSGDKron(heavyball.PSGDKron): + delayed = True + exp_avg_input = False + + param = torch.nn.Parameter(torch.randn(4, 4, device=DEVICE)) + + soap = ScheduledSOAP([param]) + assert "precondition_frequency" not in soap.param_groups[0] + assert "precond_scheduler" not in soap.param_groups[0] + + psgd = DelayedPSGDKron([param]) + first_fn = psgd.fns[0] + assert isinstance(first_fn, functools.partial) + assert first_fn.func.get_fn() is C.scale_by_delayed_psgd.get_fn() diff --git a/test/test_param_ecc_compile.py b/test/test_param_ecc_compile.py index 7182344..6b8c462 100644 --- a/test/test_param_ecc_compile.py +++ b/test/test_param_ecc_compile.py @@ -70,7 +70,7 @@ def _train_linear(compile_mode, rne=False, steps=50): try: with ctx: model = torch.nn.Linear(64, 32, bias=False, device="cuda") - opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, param_ecc="bf16+8") + opt = heavyball.AdamW(model.parameters(), lr=1e-2, param_ecc="bf16+8") data = torch.randn(32, 64, device="cuda") 
target = torch.randn(32, 32, device="cuda") for _ in range(steps): @@ -108,7 +108,7 @@ def test_ecc_populated_rne(): p, ecc, _ = _train_linear("max-autotune-no-cudagraphs", rne=True) assert ecc is not None, "param::ecc not found in optimizer state" assert ecc.any(), ( - "param::ecc all zeros with RNE rounding under compile — Inductor likely folded the f32->bf16->f32 roundtrip" + "param::ecc all zeros with RNE rounding under compile, Inductor likely folded the f32->bf16->f32 roundtrip" ) clean() diff --git a/test/test_stochastic_updates.py b/test/test_stochastic_updates.py deleted file mode 100644 index bd3f409..0000000 --- a/test/test_stochastic_updates.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -import torch -from torch import nn -from utils import REPRESENTATIVE_OPTS - -import heavyball -from heavyball.utils import clean, set_torch - -_SAVED_COMPILE_MODE = heavyball.utils.compile_mode -heavyball.utils.compile_mode = "default" - - -@pytest.fixture(autouse=True) -def _isolate_compile_mode(): - heavyball.utils.compile_mode = "default" - yield - heavyball.utils.compile_mode = _SAVED_COMPILE_MODE - - -PSGD_OPTS = [o for o in REPRESENTATIVE_OPTS if "PSGD" in o] - - -@pytest.mark.parametrize("opt", PSGD_OPTS) -def test_foreach(opt, size: int = 128, depth: int = 1, iterations: int = 512, outer_iterations: int = 2): - set_torch() - - opt = getattr(heavyball, opt) - - losses = [] - - for stochastic in [False, True]: - print("stochastic", stochastic) - torch.manual_seed(0x2131290) - losses.append([]) - - for i in range(outer_iterations): - model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda() - o = opt( - model.parameters(), - lr=1e-3, - stochastic_schedule=stochastic, - preconditioner_update_probability=lambda step: 0.1, - ) - - for _ in range(iterations): - loss = model(torch.randn((128, size), device="cuda")).square().mean() - loss.backward() - o.step() - o.zero_grad() - losses[-1].append(loss.detach()) - - del model, o - clean() 
- - stochastic = sum([l.item() for l in losses[1]]) - deterministic = sum([l.item() for l in losses[0]]) - print(f"{deterministic=}, {stochastic=}") - assert not torch.isclose(torch.tensor(deterministic), torch.tensor(stochastic), rtol=0.01) diff --git a/test/test_stochastic_utils_cpu.py b/test/test_stochastic_utils_cpu.py index 0b16f2d..693533d 100644 --- a/test/test_stochastic_utils_cpu.py +++ b/test/test_stochastic_utils_cpu.py @@ -1,8 +1,6 @@ -import os - -os.environ.setdefault("TORCH_COMPILE_DISABLE", "1") - +import pytest import torch +from torch._dynamo import config import heavyball from heavyball.utils import ( @@ -13,7 +11,15 @@ stochastic_divide_with_eps_, ) -heavyball.utils.atan2_scale = 1024.0 +config.cache_size_limit = 128 + + +@pytest.fixture(autouse=True) +def _restore_atan2_scale(): + orig = heavyball.utils.atan2_scale + heavyball.utils.atan2_scale = 1024.0 + yield + heavyball.utils.atan2_scale = orig def _average_stochastic_round(source: torch.Tensor, trials: int = 512) -> torch.Tensor: diff --git a/test/test_toy_training.py b/test/test_toy_training.py index 08859c5..6a729af 100644 --- a/test/test_toy_training.py +++ b/test/test_toy_training.py @@ -23,16 +23,14 @@ def _flatten_tensors(tensors: Iterable[torch.Tensor]): EXTRA_KWARGS = { - "ForeachAdamC": {"max_lr": 0.0025}, + "AdamC": {"max_lr": 0.0025}, } def _optimizer_params(): params = [] - for name in sorted(dir(heavyball)): - if name.startswith("_"): - continue - attr = getattr(heavyball, name) + for name in sorted(heavyball.__all__): + attr = getattr(heavyball, name, None) if not isinstance(attr, type) or not issubclass(attr, optim.Optimizer): continue if attr is optim.Optimizer: @@ -62,15 +60,15 @@ def toy_training_results(request): sig = inspect.signature(optimizer_cls.__init__) kwargs = dict(EXTRA_KWARGS.get(optimizer_name, {})) - if "foreach" in sig.parameters: - kwargs["foreach"] = True + if "multi_tensor" in sig.parameters: + kwargs["multi_tensor"] = True if optimizer_name == 
"SAMWrapper": inner_kwargs = {} - inner_sig = inspect.signature(heavyball.ForeachAdamW.__init__) - if "foreach" in inner_sig.parameters: - inner_kwargs["foreach"] = True - inner_optimizer = heavyball.ForeachAdamW(param_list, **inner_kwargs) + inner_sig = inspect.signature(heavyball.AdamW.__init__) + if "multi_tensor" in inner_sig.parameters: + inner_kwargs["multi_tensor"] = True + inner_optimizer = heavyball.AdamW(param_list, **inner_kwargs) optimizer = optimizer_cls(param_list, wrapped_optimizer=inner_optimizer, **kwargs) else: optimizer = optimizer_cls(param_list, **kwargs) diff --git a/test/test_utils_cpu.py b/test/test_utils_cpu.py index 54786a4..b9af3d7 100644 --- a/test/test_utils_cpu.py +++ b/test/test_utils_cpu.py @@ -1,5 +1,4 @@ import os -import random import warnings from copy import deepcopy @@ -40,6 +39,9 @@ # Ensure Torch dynamo stays disabled on CI runners without GPU support. os.environ.setdefault("TORCH_COMPILE_DISABLE", "1") +from torch._dynamo import config + +config.cache_size_limit = 128 _SAVED_COMPILE_MODE = heavyball.utils.compile_mode heavyball.utils.compile_mode = None @@ -173,9 +175,9 @@ def test_triu_line_roundtrip_on_cpu(): torch.arange(9, dtype=torch.float32).reshape(3, 3), ] packed = triu_to_line(tensors) - restored = line_to_triu(packed, symmetric_output=True) + restored = line_to_triu(packed) for original, rebuilt in zip(tensors, restored, strict=True): - assert torch.allclose(rebuilt, torch.triu(original) + torch.triu(original, diagonal=1).T) + assert torch.allclose(rebuilt, torch.triu(original)) def test_warn_once_only_emits_single_warning(monkeypatch): @@ -192,7 +194,7 @@ def test_warn_once_only_emits_single_warning(monkeypatch): def test_psgd_should_update_accumulates_probability(): - group = {"stochastic_schedule": False} + group = {} outcomes = [psgd_should_update(group, 0.4) for _ in range(4)] assert outcomes[:2] == [False, False] assert outcomes[2] is True @@ -200,15 +202,6 @@ def 
test_psgd_should_update_accumulates_probability(): assert group["cumulative_prob_prob_step"] == 4 -def test_psgd_should_update_stochastic_schedule_uses_rng(): - rng = random.Random(123) - group = {"stochastic_schedule": True} - calls = [psgd_should_update(group, 0.5, rng=rng) for _ in range(5)] - rng = random.Random(123) - expected = [rng.random() < 0.5 for _ in range(5)] - assert calls == expected - - def test_stochastic_math_helpers_match_expected_results(n=1024): torch.manual_seed(0x172893) a = torch.arange(n).float() @@ -226,7 +219,8 @@ def test_stochastic_math_helpers_match_expected_results(n=1024): stochastic_add_divide_(c, b, alpha=1.0, divisor=2.0) assert torch.allclose(c.float(), (a + b * 1) / 2) - orig = heavyball.utils.default_division_backend + orig_backend = heavyball.utils.default_division_backend + orig_scale = heavyball.utils.atan2_scale try: heavyball.utils.atan2_scale = 1024 for backend in heavyball.utils.DivisionBackend: @@ -235,7 +229,8 @@ def test_stochastic_math_helpers_match_expected_results(n=1024): stochastic_divide_with_eps_(c, b) assert torch.allclose(c.float(), a / b), f"Backend {backend} failed" finally: - heavyball.utils.default_division_backend = orig + heavyball.utils.default_division_backend = orig_backend + heavyball.utils.atan2_scale = orig_scale def test_stochastic_math_accuracy(): diff --git a/test/test_utils_property.py b/test/test_utils_property.py index d0c25cb..ec6c3e5 100644 --- a/test/test_utils_property.py +++ b/test/test_utils_property.py @@ -22,6 +22,9 @@ # Ensure torch.compile stays disabled on CPU-only CI runners. 
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1") +from torch._dynamo import config + +config.cache_size_limit = 128 heavyball.utils.compile_mode = None diff --git a/test/utils.py b/test/utils.py index 800c2fc..b00ff8c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -13,17 +13,10 @@ # AdEMAMix variants require 3 betas, SplitOpt requires dict param specs, # SAMWrapper and Newton variants require closures _SKIP_GET_OPTIM = { - "ForeachAdEMAMix", - "ForeachSOAPAdEMAMix", + "AdEMAMix", "SOAPAdEMAMix", "SplitOpt", "SAMWrapper", - "ForeachCachedNewtonPSGD", - "NewtonHybrid2PSGDKron", - "ForeachNewtonPSGDLRA", - "NewtonHybrid2PSGDLRA", - "NewtonPSGDLRA", - "NewtonPSGDKron", } @@ -40,8 +33,8 @@ def _fn_key(f): def _deduplicate_by_chain(names): """Keep one optimizer per unique chain of functions. - Two optimizers that differ only by foreach=True/False have identical - chains and test the same code paths — keep whichever appears first. + Two optimizers that differ only by multi_tensor=True/False have identical + chains and test the same code paths, keep whichever appears first. """ seen = set() out = []