`
+7. If you relied on `eval(); eval()` toggling back to train mode, update your code
+8. If your training loop reads `p.grad` after `step()`, pass `consume_grad=False`
+9. Remove obsolete compatibility kwargs from `heavyball.helpers` samplers
diff --git a/examples/autoencoder.py b/examples/autoencoder.py
index 20350cb..615ddba 100644
--- a/examples/autoencoder.py
+++ b/examples/autoencoder.py
@@ -96,7 +96,6 @@ def main(epochs: int, batch: int, log_interval: int = 16):
lr=1e-4,
mars=True,
lower_bound_beta=0.9,
- inverse_free=True,
precond_update_power_iterations=6,
store_triu_as_line=False,
)
diff --git a/examples/branched_optimizer.py b/examples/branched_optimizer.py
index b8a2795..73aa271 100644
--- a/examples/branched_optimizer.py
+++ b/examples/branched_optimizer.py
@@ -29,7 +29,7 @@ def __init__(
eps: float = 1e-8,
weight_decay: float = 1e-4,
warmup_steps: int = 0,
- foreach: bool = True,
+ multi_tensor: bool = True,
):
defaults = dict(
lr=lr,
@@ -38,8 +38,8 @@ def __init__(
weight_decay=weight_decay,
warmup_steps=warmup_steps,
)
- branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=_graft)
- super().__init__(params, defaults, foreach, fns=(branch,))
+ branch = C.Parallel(branches=[[C.scale_by_adam], [C.identity]], merge_fn=_graft)
+ super().__init__(params, defaults, multi_tensor, fns=(branch,))
def main(epochs: int = 20, batch_size: int = 256, subset_size: int = 4096):
diff --git a/examples/ddp_training.py b/examples/ddp_training.py
index dc27092..bc39932 100644
--- a/examples/ddp_training.py
+++ b/examples/ddp_training.py
@@ -2,8 +2,8 @@
Launch with torchrun:
torchrun --nproc_per_node=2 examples/ddp_training.py
- torchrun --nproc_per_node=2 examples/ddp_training.py --opt ForeachSOAP
- torchrun --nproc_per_node=2 examples/ddp_training.py --opt ForeachMuon --lr 0.01
+ torchrun --nproc_per_node=2 examples/ddp_training.py --opt SOAP
+ torchrun --nproc_per_node=2 examples/ddp_training.py --opt Muon --lr 0.01
All HeavyBall optimizers work transparently with DDP
"""
@@ -38,7 +38,7 @@ def make_model():
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--opt", default="ForeachAdamW")
+ parser.add_argument("--opt", default="AdamW")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--lr", type=float, default=1e-3)
diff --git a/examples/ecc_bf16.py b/examples/ecc_bf16.py
index 1b8d1bf..7effd5d 100644
--- a/examples/ecc_bf16.py
+++ b/examples/ecc_bf16.py
@@ -19,11 +19,11 @@
CONFIGS = {
"naive_fp32": lambda p: NaiveAdamW(p, lr=LR, betas=BETAS, eps=EPS, state_dtype=torch.float32),
"naive_bf16": lambda p: NaiveAdamW(p, lr=LR, betas=BETAS, eps=EPS, state_dtype=torch.bfloat16),
- "heavyball_fp32": lambda p: heavyball.ForeachAdamW(p, lr=LR, betas=BETAS, eps=EPS, storage_dtype="float32"),
- "heavyball_bf16": lambda p: heavyball.ForeachAdamW(
+ "heavyball_fp32": lambda p: heavyball.AdamW(p, lr=LR, betas=BETAS, eps=EPS, storage_dtype="float32"),
+ "heavyball_bf16": lambda p: heavyball.AdamW(
p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, storage_dtype="bfloat16"
),
- "ecc_bf16+8": lambda p: heavyball.ForeachAdamW(p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, ecc="bf16+8"),
+ "ecc_bf16+8": lambda p: heavyball.AdamW(p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, ecc="bf16+8"),
}
COLORS = {
diff --git a/examples/fsdp_training.py b/examples/fsdp_training.py
index 375d0e8..06087de 100644
--- a/examples/fsdp_training.py
+++ b/examples/fsdp_training.py
@@ -2,8 +2,8 @@
Launch with torchrun:
torchrun --nproc_per_node=2 examples/fsdp_training.py
- torchrun --nproc_per_node=2 examples/fsdp_training.py --opt ForeachSOAP
- torchrun --nproc_per_node=2 examples/fsdp_training.py --opt ForeachMuon --lr 0.01
+ torchrun --nproc_per_node=2 examples/fsdp_training.py --opt SOAP
+ torchrun --nproc_per_node=2 examples/fsdp_training.py --opt Muon --lr 0.01
Shape-aware optimizers (SOAP, Muon, PSGD, Scion, etc.) auto-detect FSDP-flattened
params and restore original shapes. No manual intervention needed.
@@ -11,7 +11,7 @@
For non-FSDP parallelism backends, capture shapes before wrapping:
shapes = heavyball.capture_param_shapes(model)
model = your_wrapper(model)
- opt = heavyball.ForeachSOAP(model.parameters(), lr=3e-3, orig_shapes=shapes)
+ opt = heavyball.SOAP(model.parameters(), lr=3e-3, orig_shapes=shapes)
"""
import argparse
@@ -44,7 +44,7 @@ def make_model():
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--opt", default="ForeachAdamW")
+ parser.add_argument("--opt", default="AdamW")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--lr", type=float, default=1e-3)
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index e8c3eaf..db48ee3 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -7,6 +7,8 @@
from . import chainable as C
from . import utils
+ShapeMap = dict[int, tuple[int, ...]]
+
class SGD(C.BaseOpt):
def __init__(
@@ -16,7 +18,7 @@ def __init__(
beta=0.9,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -27,13 +29,14 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
-class ForeachAdamW(C.BaseOpt):
+class AdamW(C.BaseOpt):
def __init__(
self,
params,
@@ -42,7 +45,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -55,13 +58,14 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
-class ForeachNAdam(C.BaseOpt):
+class NAdam(C.BaseOpt):
def __init__(
self,
params,
@@ -72,7 +76,7 @@ def __init__(
momentum_decay: float = 4e-3,
decoupled_weight_decay: bool = False,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -85,13 +89,14 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,))
-class ForeachAdEMAMix(C.BaseOpt):
+class AdEMAMix(C.BaseOpt):
def __init__(
self,
params,
@@ -103,7 +108,7 @@ def __init__(
beta3_warmup: Optional[int] = None,
alpha_warmup: Optional[int] = None,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -114,13 +119,14 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
if len(betas) != 3:
raise ValueError("AdEMAMix expects betas with three coefficients.")
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,))
class UnscaledAdamW(C.BaseOpt):
@@ -132,7 +138,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -145,12 +151,11 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(
- params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
- )
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,))
class SUDSAdamW(C.BaseOpt):
@@ -162,7 +167,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -176,10 +181,11 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
class Scion(C.BaseOpt):
@@ -191,7 +197,7 @@ def __init__(
eps: float = 1e-8,
weight_decay: float = 0,
warmup_steps: int = 0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -204,6 +210,7 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
if lr < 0:
@@ -221,12 +228,10 @@ def __init__(
defaults["scale"] = scale
defaults.pop("momentum", None)
- super().__init__(
- params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm)
- )
+ super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm))
-class ForeachAdamC(C.BaseOpt):
+class AdamC(C.BaseOpt):
def __init__(
self,
params,
@@ -236,7 +241,7 @@ def __init__(
weight_decay=0,
max_lr: float | None = None,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -249,6 +254,7 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
if max_lr is None:
@@ -258,10 +264,10 @@ def __init__(
max_lr = lr
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
-class ForeachRMSprop(C.BaseOpt):
+class RMSprop(C.BaseOpt):
"""
Debiased RMSprop (not torch.optim.RMSprop)
"""
@@ -276,7 +282,7 @@ def __init__(
warmup_steps=0,
r=0.0,
weight_lr_power=2.0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -289,13 +295,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -303,7 +309,92 @@ def __init__(
)
-class ForeachSFAdamW(C.ScheduleFree):
+class HyperBallAdamW(C.BaseOpt):
+ def __init__(
+ self,
+ params,
+ lr=0.0025,
+ betas=(0.9, 0.99),
+ eps=1e-8,
+ weight_decay=0,
+ warmup_steps=0,
+ multi_tensor: bool = True,
+ storage_dtype: str = "float32",
+ mars: bool = False,
+ caution: bool = False,
+ mars_gamma: float = 0.0025,
+ gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default,
+ palm: bool = C.use_default,
+ beta2_scale: float = 0.8,
+ compile_step: bool = C.use_default,
+ promote: bool = C.use_default,
+ ecc: str | None = None,
+ param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
+ **kwargs,
+ ):
+ params, defaults = C._build_defaults(locals())
+ super().__init__(
+ params,
+ defaults,
+ gradient_clipping,
+ update_clipping,
+ palm,
+ fns=(
+ C.scale_by_exp_avg_sq,
+ C.route(
+ (lambda p: p.ndim >= 2, C.update_by_hyperball),
+ default=C.apply_update,
+ ),
+ ),
+ )
+
+
+class MuonAdamW(C.BaseOpt):
+ def __init__(
+ self,
+ params,
+ lr=0.0025,
+ betas=(0.9, 0.99),
+ eps=1e-8,
+ weight_decay=0,
+ warmup_steps=0,
+ multi_tensor: bool = True,
+ storage_dtype: str = "float32",
+ mars: bool = False,
+ caution: bool = False,
+ mars_gamma: float = 0.0025,
+ gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default,
+ palm: bool = C.use_default,
+ beta2_scale: float = 0.8,
+ nesterov: bool = True,
+ compile_step: bool = C.use_default,
+ promote: bool = C.use_default,
+ ecc: str | None = None,
+ param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
+ **kwargs,
+ ):
+ params, defaults = C._build_defaults(locals())
+ ema = C.nesterov_ema if nesterov else C.exp_avg
+ super().__init__(
+ params,
+ defaults,
+ gradient_clipping,
+ update_clipping,
+ palm,
+ fns=(
+ C.route(
+ (lambda p: p.ndim >= 2, (ema, C.orthogonalize_update)),
+ default=C.scale_by_adam,
+ ),
+ ),
+ )
+
+
+class SFAdamW(C.ScheduleFree):
def __init__(
self,
params,
@@ -314,7 +405,7 @@ def __init__(
warmup_steps=0,
r=0.0,
weight_lr_power=2.0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -327,13 +418,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -352,7 +443,7 @@ def __init__(
warmup_steps=0,
r=0.0,
weight_lr_power=2.0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -366,13 +457,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -380,11 +471,7 @@ def __init__(
)
-class PaLMForeachSFAdamW(ForeachSFAdamW):
- palm: bool = True
-
-
-class ForeachADOPT(C.BaseOpt):
+class ADOPT(C.BaseOpt):
def __init__(
self,
params,
@@ -393,7 +480,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -406,13 +493,14 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
-class ForeachMuon(C.BaseOpt):
+class Muon(C.BaseOpt):
def __init__(
self,
params,
@@ -421,7 +509,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -436,6 +524,7 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
@@ -450,7 +539,6 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -458,7 +546,7 @@ def __init__(
)
-class ForeachLaProp(C.BaseOpt):
+class LaProp(C.BaseOpt):
def __init__(
self,
params,
@@ -467,7 +555,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -480,10 +568,11 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
class MuonLaProp(C.BaseOpt):
@@ -495,7 +584,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -508,13 +597,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -522,9 +611,31 @@ def __init__(
)
-class ForeachSOAP(C.BaseOpt):
+class SOAPBase(C.BaseOpt):
+ use_precond_schedule: bool = False
+
+ def _build_soap_defaults(self, locals_dict, fns):
+ use_precond_schedule = C.default(locals_dict["use_precond_schedule"], self.use_precond_schedule)
+ params, defaults = C._build_defaults(locals_dict)
+ if use_precond_schedule:
+ del defaults["precondition_frequency"]
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+ else:
+ del defaults["precond_scheduler"]
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+ super().__init__(
+ params,
+ defaults,
+ locals_dict["gradient_clipping"],
+ locals_dict["update_clipping"],
+ locals_dict.get("palm", False),
+ fns=fns,
+ )
+
+
+class SOAP(SOAPBase):
"""
- ForeachSOAP
+ SOAP
Sources:
Baseline SOAP:
@@ -534,8 +645,6 @@ class ForeachSOAP(C.BaseOpt):
https://github.com/nikhilvyas/SOAP
"""
- use_precond_schedule: bool = False
-
def __init__(
self,
params,
@@ -548,11 +657,9 @@ def __init__(
max_precond_dim: int = 2048, #
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -563,38 +670,18 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
precond_grad_accum: bool = False,
compile_step: bool = C.use_default,
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
-
- params, defaults = C._build_defaults(locals())
-
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- foreach,
- gradient_clipping,
- update_clipping,
- palm, #
- fns=(C.scale_by_soap,),
- )
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap,))
-class ForeachSOAPNAdam(C.BaseOpt):
- use_precond_schedule: bool = False
-
+class SOAPNAdam(SOAPBase):
def __init__(
self,
params,
@@ -607,11 +694,9 @@ def __init__(
max_precond_dim: int = 2048,
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -622,7 +707,6 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
precond_grad_accum: bool = False,
momentum_decay: float = 4e-3,
decoupled_weight_decay: bool = False,
@@ -630,32 +714,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap_nadam,))
- params, defaults = C._build_defaults(locals())
-
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- foreach,
- gradient_clipping,
- update_clipping,
- palm,
- fns=(C.scale_by_soap_nadam,),
- )
-
-
-class ForeachSOAPAdEMAMix(C.BaseOpt):
- use_precond_schedule: bool = False
+class SOAPAdEMAMix(SOAPBase):
def __init__(
self,
params,
@@ -668,11 +733,9 @@ def __init__(
max_precond_dim: int = 2048,
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -683,7 +746,6 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
precond_grad_accum: bool = False,
alpha: float = 2.0,
beta3_warmup: int | None = None,
@@ -692,30 +754,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap_ademamix,))
- params, defaults = C._build_defaults(locals())
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- foreach,
- gradient_clipping,
- update_clipping,
- palm,
- fns=(C.scale_by_soap_ademamix,),
- )
-
-
-class ForeachSignLaProp(C.BaseOpt):
+class SignLaProp(C.BaseOpt):
def __init__(
self,
params,
@@ -724,7 +769,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -737,13 +782,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -751,9 +796,9 @@ def __init__(
)
-class ForeachSOLP(C.BaseOpt):
+class SOLP(SOAPBase):
"""
- ForeachSOLP
+ SOLP
Sources:
Baseline SOAP:
@@ -763,8 +808,6 @@ class ForeachSOLP(C.BaseOpt):
https://github.com/nikhilvyas/SOAP
"""
- use_precond_schedule: bool = False
-
def __init__(
self,
params,
@@ -777,11 +820,9 @@ def __init__(
max_precond_dim: int = 2048, #
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -792,46 +833,14 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
compile_step: bool = C.use_default,
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
-
- params, defaults = C._build_defaults(locals())
-
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- foreach,
- gradient_clipping,
- update_clipping,
- palm, #
- fns=(C.scale_by_soap_laprop,),
- )
-
-
-class PaLMForeachSOAP(ForeachSOAP):
- use_precond_schedule: bool = False
- palm: bool = True
-
-
-class PrecondScheduleForeachSOAP(ForeachSOAP):
- use_precond_schedule: bool = True
-
-
-class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
- use_precond_schedule: bool = True
- palm: bool = True
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap_laprop,))
class OrthoLaProp(C.BaseOpt):
@@ -843,7 +852,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -856,13 +865,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -879,7 +888,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -892,13 +901,13 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
params, defaults = C._build_defaults(locals())
super().__init__(
params,
defaults,
- foreach,
gradient_clipping,
update_clipping,
palm,
@@ -906,18 +915,43 @@ def __init__(
)
-class ForeachPSGDKron(C.BaseOpt):
+class PSGDBase(C.BaseOpt):
+ delayed: bool = False
+ cached: bool = False
+ exp_avg_input: bool = True
+
+ def _build_psgd_defaults(
+ self, locals_dict, fns, *, default_update_clipping=utils.trust_region_clip_, extra_defaults=None
+ ):
+ exp_avg_input = C.default(locals_dict.get("exp_avg_input", C.use_default), self.exp_avg_input)
+ update_clipping = C.default(locals_dict["update_clipping"], default_update_clipping)
+ locals_dict = {**locals_dict, "exp_avg_input": exp_avg_input, "update_clipping": update_clipping}
+ params, defaults = C._build_defaults(locals_dict)
+
+ self.precond_schedule = C.default(
+ defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
+ )
+
+ if extra_defaults:
+ defaults.update(extra_defaults)
+
+ super().__init__(
+ params,
+ defaults,
+ locals_dict["gradient_clipping"],
+ update_clipping,
+ False,
+ fns=(*(C.exp_avg,) * exp_avg_input, *fns),
+ )
+
+
+class PSGDKron(PSGDBase):
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
Modified under Creative Commons Attribution 4.0 International
Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
"""
- delayed: bool = False
- cached: bool = False
- exp_avg_input: bool = True
- quad: bool = False
-
def __init__(
self,
params,
@@ -934,9 +968,8 @@ def __init__(
merge_dims: bool = False,
split: bool = False,
store_triu_as_line: bool = True,
- foreach: bool = True,
+ multi_tensor: bool = True,
q_dtype="float32",
- stochastic_schedule: bool = False,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -946,12 +979,9 @@ def __init__(
exp_avg_input: Optional[bool] = C.use_default,
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default, #
- adaptive: bool = False,
- ortho_method: Optional[str] = None, # If None, no orthogonalization
precond_grad_accum: bool = False,
- lower_bound_beta: float = 0.9, # 0.0 recovers pre-2.0.0 PSGD
- inverse_free: bool = C.use_default,
- dampening: float = 2**-13,
+ lower_bound_beta: float = 0.9,
+ dampening: float = 1e-9,
precond_update_power_iterations: int = 2,
# expert parameters
precond_init_scale=None,
@@ -966,77 +996,140 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
delayed = C.default(delayed, self.delayed)
cached = C.default(cached, self.cached)
- exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
- update_clipping = C.default(update_clipping, utils.trust_region_clip_)
- inverse_free = C.default(inverse_free, self.quad)
- if inverse_free:
- raise ValueError(
- "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
- )
-
- params, defaults = C._build_defaults(locals())
-
- self.precond_schedule = C.default(
- defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
+ self._build_psgd_defaults(
+ locals(),
+ fns=(functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),),
)
- super().__init__(
- params,
- defaults,
- foreach,
- gradient_clipping,
- update_clipping,
- False, #
- fns=(
- *(C.exp_avg,) * exp_avg_input,
- functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
- ),
- )
-
-
-class ForeachPurePSGD(ForeachPSGDKron):
- exp_avg_input: bool = False
-
-
-class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
- delayed: bool = True
- cached: bool = True
-
-
-class ForeachCachedPSGDKron(ForeachPSGDKron):
- cached: bool = True
+class LATHER(PSGDBase):
+ """
+ Lie-group Adam Through Harmonic Eigenbasis Rotations.
+ Runs Adam in the approximate eigenspace induced by the PSGD-Kron preconditioner, then maps back.
+ """
-class ForeachDelayedPSGD(ForeachPSGDKron):
- delayed: bool = True
-
+ def __init__(
+ self,
+ params,
+ lr=0.001,
+ beta=None,
+ betas=(0.9, 0.999),
+ eps: float = 1e-8,
+ weight_decay=0.0,
+ preconditioner_update_probability=C.use_default,
+ max_size_triangular=2048,
+ min_ndim_triangular=2,
+ memory_save_mode=None,
+ momentum_into_precond_update=True,
+ warmup_steps: int = 0,
+ merge_dims: bool = False,
+ split: bool = False,
+ store_triu_as_line: bool = True,
+ multi_tensor: bool = True,
+ q_dtype="float32",
+ storage_dtype: str = "float32",
+ mars: bool = False,
+ caution: bool = False,
+ mars_gamma: float = 0.0025,
+ gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default,
+ precond_grad_accum: bool = False,
+ lower_bound_beta: float = 0.9,
+ dampening: float = 1e-9,
+ precond_update_power_iterations: int = 2,
+ precond_init_scale=None,
+ precond_init_scale_scale: float = 1,
+ precond_init_scale_power: Optional[float] = None,
+ precond_lr: float = 0.1,
+ finite_differences: bool = C.use_default,
+ fallback_to_finite_differences: bool = C.use_default,
+ hvp_interval: int = C.use_default,
+ hessian_approx: bool = C.use_default,
+ compile_step: bool = C.use_default,
+ promote: bool = C.use_default,
+ ecc: str | None = None,
+ param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
+ **kwargs,
+ ):
+ self._build_psgd_defaults(
+ {**locals(), "exp_avg_input": False},
+ fns=(C.scale_by_lather,),
+ )
-class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
- hessian_approx = True
+class PSGDPRO(PSGDBase):
+ """
+ PSGD with Q0.5EQ1.5 (PRO/Procrustes) preconditioner update.
+ Solve-free alternative to standard PSGD-Kron (EQ method).
+ Reference: https://github.com/lixilinx/psgd_torch
+ """
-class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
- hvp_interval = 2
+ def __init__(
+ self,
+ params,
+ lr=0.001,
+ beta=None,
+ betas=(0.9, 0.999),
+ weight_decay=0.0,
+ preconditioner_update_probability=C.use_default,
+ max_size_triangular=2048,
+ min_ndim_triangular=2,
+ memory_save_mode=None,
+ momentum_into_precond_update=True,
+ warmup_steps: int = 0,
+ merge_dims: bool = False,
+ split: bool = False,
+ multi_tensor: bool = True,
+ q_dtype="float32",
+ storage_dtype: str = "float32",
+ mars: bool = False,
+ caution: bool = False,
+ mars_gamma: float = 0.0025,
+ cached: Optional[bool] = C.use_default,
+ exp_avg_input: Optional[bool] = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default,
+ precond_grad_accum: bool = False,
+ lower_bound_beta: float = 0.9,
+ dampening: float = 1e-9,
+ precond_update_power_iterations: int = 2,
+ precond_init_scale=None,
+ precond_init_scale_scale: float = 1,
+ precond_init_scale_power: Optional[float] = None,
+ precond_lr: float = 0.1,
+ compile_step: bool = C.use_default,
+ promote: bool = C.use_default,
+ ecc: str | None = None,
+ param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
+ **kwargs,
+ ):
+ cached = C.default(cached, self.cached)
+ self._build_psgd_defaults(
+ locals(),
+ fns=(functools.partial(C.scale_by_psgd_pro, cached=cached),),
+ default_update_clipping=None,
+ extra_defaults={"store_triu_as_line": False},
+ )
-class ForeachPSGDLRA(C.BaseOpt):
+class PSGDLRA(PSGDBase):
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
Modified under Creative Commons Attribution 4.0 International
Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- Note: `foreach=True` (default) uses a single global low-rank approximation shared across all
- parameters, while `foreach=False` fits an independent per-parameter LRA. These are different
+ Note: `multi_tensor=True` (default) uses a single global low-rank approximation shared across all
+ parameters, while `multi_tensor=False` fits an independent per-parameter LRA. These are different
algorithms and will produce different results.
"""
- delayed: bool = False
- exp_avg_input: bool = True
-
def __init__(
self,
params,
@@ -1047,9 +1140,8 @@ def __init__(
momentum_into_precond_update=True,
rank: Optional[int] = None,
warmup_steps: int = 0,
- foreach: bool = True, # True: global LRA across all params. False: independent per-param LRA.
+ multi_tensor: bool = True, # True: global LRA across all params. False: independent per-param LRA.
q_dtype="float32",
- stochastic_schedule: bool = False,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -1058,8 +1150,8 @@ def __init__(
exp_avg_input: Optional[bool] = C.use_default,
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
- eps: float = 1e-8, #
- precond_grad_accum: bool = False, # expert parameters
+ eps: float = 1e-8,
+ precond_grad_accum: bool = False,
precond_init_scale=None,
precond_init_scale_scale: float = 1,
precond_init_scale_power: Optional[float] = None,
@@ -1072,49 +1164,24 @@ def __init__(
promote: bool = C.use_default,
ecc: str | None = None,
param_ecc: str | None = None,
+ orig_shapes: ShapeMap | None = None,
**kwargs,
):
delayed = C.default(delayed, self.delayed)
- exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
- update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
- params, defaults = C._build_defaults(locals())
-
- self.precond_schedule = C.default(
- defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
- )
-
if rank is None:
utils.warn_once(
f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
)
params = list(params)
- defaults["rank"] = max(1, round(math.log2(sum(p.numel() for p in params))))
- utils.warn_once(f"rank was set to {defaults['rank']}")
+ rank = max(1, round(math.log2(sum(p.numel() for p in params))))
+ utils.warn_once(f"rank was set to {rank}")
- super().__init__(
- params,
- defaults,
- foreach,
- gradient_clipping,
- update_clipping,
- False, #
- fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
+ self._build_psgd_defaults(
+ locals(),
+ fns=(C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,),
)
-class ForeachDelayedPSGDLRA(ForeachPSGDLRA):
- delayed: bool = True
-
-
-class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
- hessian_approx = True
-
-
-class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
- hvp_interval = 2
-
-
class SplitOpt(utils.StatefulOptimizer):
"""
Delegates different parameter groups to different underlying optimizers.
@@ -1135,7 +1202,7 @@ def __init__(self, specs):
all_params.extend(params)
if not self.optimizers:
raise ValueError("No optimizers created")
- super().__init__(all_params, {}, foreach=True)
+ super().__init__(all_params, {"multi_tensor": True})
def _step(self, group):
pass
@@ -1166,7 +1233,7 @@ class SAMWrapper(torch.optim.Optimizer):
def __init__(
self,
params,
- wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
+ wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = AdamW,
ball: float = 0.1,
):
params = list(params)
@@ -1210,31 +1277,10 @@ def zero_grad(self, set_to_none: bool = True):
self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
-PalmForEachSoap = PaLMForeachSOAP
-PaLMSOAP = PaLMForeachSOAP
-PaLMSFAdamW = PaLMForeachSFAdamW
-SOAP = ForeachSOAP
-SOAPAdEMAMix = ForeachSOAPAdEMAMix
-SOAPNAdam = ForeachSOAPNAdam
-SFAdamW = ForeachSFAdamW
-LaProp = ForeachLaProp
-ADOPT = ForeachADOPT
-RMSprop = ForeachRMSprop
-PrecondScheduleSOAP = PrecondScheduleForeachSOAP
-PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
-PSGDKron = ForeachPSGDKron
-AdamW = ForeachAdamW
-NAdam = ForeachNAdam
-PurePSGD = ForeachPurePSGD
-DelayedPSGD = ForeachDelayedPSGD
-CachedPSGDKron = ForeachCachedPSGDKron
-CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
-Muon = ForeachMuon
-SignLaProp = ForeachSignLaProp
-DelayedPSGDLRA = ForeachDelayedPSGDLRA
-PSGDLRA = ForeachPSGDLRA
-NewtonPSGDLRA = ForeachNewtonPSGDLRA
-NewtonPSGDKron = ForeachCachedNewtonPSGD
-
capture_param_shapes = utils.capture_param_shapes
-__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
+_BASE_CLASSES = {SOAPBase, PSGDBase}
+__all__ = [
+ k
+ for k, v in globals().items()
+ if isinstance(v, type) and issubclass(v, torch.optim.Optimizer) and v not in _BASE_CLASSES
+]
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index eca7ea7..100c7c2 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -2,7 +2,7 @@
import copy
import functools
import math
-import random
+import warnings
from collections.abc import Iterable as _Iterable
from typing import Iterable, List, Literal, Optional, Union
@@ -25,13 +25,6 @@ def _key_in_state(state, key):
return True
-def _inplace_guard_(state, key, template_fn):
- key_not_in_state = not _key_in_state(state, key)
- if key_not_in_state:
- template_fn()
- return key_not_in_state
-
-
def _guard_in_state(state, key, template_fn):
if not _key_in_state(state, key):
state[key] = template_fn()
@@ -45,7 +38,6 @@ def __init__(self, fn, names: list[str] | None = None):
self.fn = fn
self.fn_name = self.get_fn().__name__
self.transform_idx = None
- self.is_initialized = False
self.names = names
def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
@@ -55,7 +47,7 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
raise NotImplementedError
def __call__(self, state, group, update, grad, param, *args, **kwargs):
- states = [state(p) for p in param]
+ states = state if isinstance(state, list) else [state(p) for p in param]
skip_update = False
for st, a in zip(states, zip(update, grad, param, *args)):
if self.transform_idx not in st.get("is_initialized", set()):
@@ -77,28 +69,106 @@ def get_fn(self):
return self.fn.get_fn()
return self.fn
+ def _build_val_names(self):
+ self._val_names = {name: f"{self.fn_name}_{name}_{self.transform_idx}" for name in self.names}
+
def val_name(self, name):
- assert self.transform_idx is not None
- return f"{self.fn_name}_{name}_{self.transform_idx}"
+ return self._val_names[name]
def __repr__(self):
return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
-class Branch:
+def _enforce_uniform_skip(results):
+ skips = [skip for _, skip, _ in results]
+ if not skips:
+ return False
+ if all(skips):
+ return True
+ if not any(skips):
+ return False
+ raise ValueError("All branches must uniformly skip or not skip updates")
+
+
+def _normalize_chain(fns):
+ if fns is None:
+ return None
+ return fns if isinstance(fns, (list, tuple)) else (fns,)
+
+
+class Parallel:
def __init__(self, branches: List[List[callable]], merge_fn: callable):
self.branches = branches
self.merge_fn = merge_fn
def __call__(self, state, group, update, grad, param):
- outputs = []
+ results = []
for branch in self.branches:
branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
- branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
- if skip_update:
- raise ValueError("Branches should not skip updates")
- outputs.append(branch_update)
- return self.merge_fn(outputs)
+ u, skip = _inner_chain(state, group, branch_update, grad, param, *branch)
+ results.append((u, skip, None))
+ if _enforce_uniform_skip(results):
+ raise SkipUpdate from None
+ return self.merge_fn([u for u, _, _ in results])
+
+
+class Route:
+ """Route params by predicate through different fn chains.
+
+ Takes arbitrary (predicate, fns) pairs. Each param is assigned to the first
+ matching route; unmatched params use the default chain (None = passthrough).
+ All routes must uniformly either skip or not skip updates.
+ """
+
+ def __init__(self, *routes, default=None):
+ self.routes = [(pred, _normalize_chain(fns)) for pred, fns in routes]
+ self.default = _normalize_chain(default)
+
+ def __call__(self, state, group, update, grad, param):
+ buckets = {}
+ assigned = set()
+ for j, (pred, _) in enumerate(self.routes):
+ for i, p in enumerate(param):
+ if i not in assigned and pred(p):
+ buckets.setdefault(j, []).append(i)
+ assigned.add(i)
+ default_idx = [i for i in range(len(param)) if i not in assigned]
+
+ def _sel(lst, idx):
+ return [lst[i] for i in idx]
+
+ caution = group["caution"]
+ results = []
+
+ all_chains = [(buckets.get(j), fns) for j, (_, fns) in enumerate(self.routes)]
+ if default_idx:
+ all_chains.append((default_idx, self.default))
+
+ for idx, fns in all_chains:
+ if not idx:
+ continue
+ group["caution"] = caution
+ if fns is not None:
+ u, skip = _inner_chain(
+ _sel(state, idx), group, _sel(update, idx), _sel(grad, idx), _sel(param, idx), *fns
+ )
+ else:
+ u, skip = _sel(update, idx), False
+ results.append((u, skip, idx))
+
+ if _enforce_uniform_skip(results):
+ raise SkipUpdate from None
+
+ out = [None] * len(param)
+ for u_list, _, idx in results:
+ if u_list is not None:
+ for i, u in zip(idx, u_list):
+ out[i] = u
+ return out
+
+
+def route(*routes, default=None):
+ return Route(*routes, default=default)
def _zero_guard(state, key, ref, dtype):
@@ -112,12 +182,38 @@ def _storage_dtype(group):
_PASSTHROUGH_KWARGS = {"orig_shapes"}
+_RENAMED_KWARGS = {"foreach": "multi_tensor"}
+
+_REMOVED_KWARGS = frozenset(
+ {
+ "stochastic_schedule",
+ "normalize_grads",
+ "correct_bias",
+ "inverse_free",
+ "adaptive",
+ }
+)
+
def _build_defaults(locals_dict):
d = locals_dict.copy()
- d.pop("self")
+ d.pop("self", None)
params = d.pop("params")
kwargs = d.pop("kwargs")
+
+ for old, new in _RENAMED_KWARGS.items():
+ if old in kwargs:
+ warnings.warn(
+ f"'{old}' was renamed to '{new}' in HeavyBall 3.0. Pass '{new}' instead.", FutureWarning, stacklevel=4
+ )
+ d[new] = kwargs.pop(old)
+
+ hit = _REMOVED_KWARGS & kwargs.keys()
+ if hit:
+ raise TypeError(
+ f"Removed in HeavyBall 3.0: {', '.join(sorted(hit))}. See docs/heavyball3.md for migration details."
+ )
+
d.update(kwargs)
unknown = {k: v for k, v in kwargs.items() if k not in _PASSTHROUGH_KWARGS}
if unknown:
@@ -197,12 +293,11 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
if ecc is None:
return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
- states = [state(p) for p in param]
names = [self.val_name(n) for n in self.names]
- primary_vars = [[st[vn] for st in states] for vn in names]
+ primary_vars = [[st[vn] for st in state] for vn in names]
with contextlib.ExitStack() as stack:
for vn, plist in zip(names, primary_vars):
- corrs = [st[vn + "::ecc"] for st in states]
+ corrs = [st[vn + "::ecc"] for st in state]
stack.enter_context(ecc.attached(plist, corrs))
return self.fn(state, group, update, grad, param, *args, *primary_vars, **kwargs)
@@ -210,45 +305,50 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
class PrecondGradAccumGuard(FunctionTransform):
def __init__(self, fn):
super().__init__(fn, ["precond_grad_accum"])
- self.steps_taken = 0
- self.pass_through = None
+ self.steps_taken_key = None
+
+ def _build_val_names(self):
+ super()._build_val_names()
+ self.steps_taken_key = f"_{self.fn_name}_steps_taken_{self.transform_idx}"
- def _accum(self, state, new):
- self.steps_taken += 1
+ def _accum(self, group, state, new):
+ group[self.steps_taken_key] = group.get(self.steps_taken_key, 0) + 1
utils.stochastic_add_(state, new)
- def _reset(self, state):
- if self.steps_taken != 0:
- self.steps_taken = 0
+ def _reset(self, group, state):
+ if group.get(self.steps_taken_key, 0) != 0:
+ group[self.steps_taken_key] = 0
utils.zero_(state)
def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
- if self.pass_through is None:
- self.pass_through = not group.get("precond_grad_accum", False)
- if self.pass_through is False:
- for name in self.names:
- _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+ if not group.get("precond_grad_accum", False):
+ return
+ for name in self.names:
+ _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
base_grad = update if group.get("momentum_into_precond_update", True) else grad
- if self.pass_through:
+ if not group.get("precond_grad_accum", False):
return self.fn(state, group, update, grad, param, *args, base_grad, **kwargs)
(vars,) = vars
+ steps_taken = group.get(self.steps_taken_key, 0)
+ accum_state = None
if group["is_preconditioning"]:
- if self.steps_taken:
- self._accum(vars, base_grad)
- utils.stochastic_multiply_(vars, 1 / self.steps_taken)
+ if steps_taken:
+ self._accum(group, vars, base_grad)
+ utils.stochastic_multiply_(vars, 1 / group[self.steps_taken_key])
+ accum_state = vars
else:
vars = base_grad
else:
- self._accum(vars, base_grad)
+ self._accum(group, vars, base_grad)
vars = base_grad
try:
out = self.fn(state, group, update, grad, param, *args, vars, **kwargs)
finally:
- if group["is_preconditioning"]:
- self._reset(vars)
+ if accum_state is not None:
+ self._reset(group, accum_state)
return out
@@ -272,17 +372,6 @@ def __init__(self, fn, names, init_fn, skip_first: bool = True):
super().__init__(fn, names)
self.init_fn = init_fn
self.skip_first = skip_first
- self.named_to_anonymous = None
- self.anonymous_to_named = None
-
- def _map(self, state_fn, param, mapping):
- for p in param:
- state = state_fn(p)
- for name, mapped in mapping.items():
- if mapped in state:
- raise ValueError(f"Name {name} already mapped to {mapped}")
- if name in state:
- state[mapped] = state.pop(name)
def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
self.init_fn(state, group, update, grad, param, **kwargs)
@@ -296,12 +385,19 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
class NoState(FunctionTransform):
+ needs_init = False
+
def __call__(self, state, group, update, grad, param, *args, **kwargs):
return self.fn(group, update, grad, param, *args, **kwargs)
-class NoStateNoForeach(FunctionTransform):
+class NoStateNoMultiTensor(FunctionTransform):
def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ states = state if isinstance(state, list) else [state(p) for p in param]
+ for st in states:
+ if "is_initialized" not in st:
+ st["is_initialized"] = set()
+ st["is_initialized"].add(self.transform_idx)
updates = []
skip_update = False
for a in zip(update, grad, param, *args):
@@ -324,6 +420,8 @@ def _view_preserve_ecc(src, target):
class SqueezeGrad(FunctionTransform):
+ needs_init = False
+
def __call__(self, state, group, update, grad, param, *args, **kwargs):
original_shapes = [u.shape for u in update]
update = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in update]
@@ -353,6 +451,32 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
return self.fn(state, group, update, grad, param, *args, **kwargs)
+class WarmupGuard(FunctionTransform):
+ def __init__(self, fn, warmup_fns):
+ super().__init__(fn, names=[])
+ self.warmup_fns = warmup_fns
+ self.warmup_key = None
+
+ def _build_val_names(self):
+ super()._build_val_names()
+ self.warmup_key = f"_warmup_{self.transform_idx}"
+
+ def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ states = state if isinstance(state, list) else [state(p) for p in param]
+ warmup_step = min(st.get(self.warmup_key, 0) for st in states)
+ if warmup_step < len(self.warmup_fns):
+ fn = self.warmup_fns[warmup_step]
+ for st, a in zip(states, zip(update, grad, param, *args)):
+ fn(st, group, *a, **kwargs)
+ st[self.warmup_key] = st.get(self.warmup_key, 0) + 1
+ raise SkipUpdate from None
+ for st in states:
+ if "is_initialized" not in st:
+ st["is_initialized"] = set()
+ st["is_initialized"].add(self.transform_idx)
+ return self.fn(state, group, update, grad, param, *args, **kwargs)
+
+
needs_full_param = functools.partial(TagGuard, needs_full_param=True)
@@ -368,12 +492,16 @@ def general_guard(*names, init_fn, skip_first: bool = True):
return functools.partial(GeneralGuard, names=names, init_fn=init_fn, skip_first=skip_first)
+def warmup_guard(*warmup_fns):
+ return functools.partial(WarmupGuard, warmup_fns=list(warmup_fns))
+
+
def no_state(fn):
return NoState(fn)
-def no_state_no_foreach(fn):
- return NoStateNoForeach(fn)
+def no_state_no_multi_tensor(fn):
+ return NoStateNoMultiTensor(fn)
class SkipUpdate(ValueError):
@@ -405,6 +533,12 @@ def identity(state, group, update, grad, param):
return update
+@no_state
+def apply_update(group, update, grad, param):
+ utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
+ raise SkipUpdate from None
+
+
@zero_guard("exp_avg")
@no_state
def weight_decay_to_ema(group, update, grad, param, exp_avg):
@@ -619,14 +753,14 @@ def orthogonalize_grad_to_param(group, update, grad, param):
@copy_guard(2, "z")
@no_state
def update_by_schedule_free(group, update, grad, param, z):
- # Compute weight_sum once per step, not per param in no-foreach mode.
- if group.get("_sf_step") != group["step"]:
- weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
+ # Compute weight_sum once per step, not per param in no-multi_tensor mode.
+ if group.get("_sf_step") is not group["step"]:
+ weight = abs(group["lr"]) ** group["weight_lr_power"] * group["step"].clamp(min=1) ** group["r"]
group["weight_sum"] = group.get("weight_sum", 0) + weight
group["_sf_step"] = group["step"]
weight_sum = group["weight_sum"]
- weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
+ weight = abs(group["lr"]) ** group["weight_lr_power"] * group["step"].clamp(min=1) ** group["r"]
try:
ckp1 = weight / weight_sum
except ZeroDivisionError:
@@ -658,28 +792,23 @@ def update_by_msam(group, update, grad, param, z, exp_avg):
raise SkipUpdate from None
+def _adopt_warmup_1(state, group, update, grad, param, exp_avg, exp_avg_sq):
+ utils.scale_by_exp_avg_sq_([exp_avg_sq], [update], 0, group["eps"])
+
+
+def _adopt_warmup_2(state, group, update, grad, param, exp_avg, exp_avg_sq):
+ u = utils.promote(update)
+ easq = utils.promote(exp_avg_sq)
+ utils.copy_stochastic_(exp_avg, u / easq.sqrt().clamp_(min=group["eps"]))
+ utils.scale_by_exp_avg_sq_(
+ [exp_avg_sq], [update], utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"]
+ )
+
+
@zero_guard("exp_avg", "exp_avg_sq")
+@warmup_guard(_adopt_warmup_1, _adopt_warmup_2)
@no_state
def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
- if group["step"] == 1:
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
- raise SkipUpdate from None
-
- if group["step"] == 2:
- update = utils.promote(update)
- easq = utils.promote(exp_avg_sq)
- [
- utils.copy_stochastic_(ea, u / easq_.sqrt().clamp_(min=group["eps"]))
- for ea, u, easq_ in zip(exp_avg, update, easq)
- ]
- utils.scale_by_exp_avg_sq_(
- exp_avg_sq,
- update,
- utils.beta_debias(utils.get_beta2(group), group["step"]),
- group["eps"],
- )
- raise SkipUpdate from None
-
utils.fused_adopt_(
param,
update,
@@ -697,14 +826,15 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
raise SkipUpdate from None
+def _suds_warmup_1(state, group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
+ utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
+
+
@needs_full_param
@zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
-@no_state_no_foreach
+@warmup_guard(_suds_warmup_1)
+@no_state_no_multi_tensor
def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
- if group["step"] == 1:
- utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
- raise SkipUpdate from None
-
precond_update, w = utils.eigvecs_product_rank1(update.flatten(), fisher_approx.flatten().to(update.dtype))
precond_update = utils.adam_(
exp_avg,
@@ -737,27 +867,9 @@ def scale_by_unscaled_adam(group, update, grad, param, exp_avg, exp_avg_sq):
@zero_guard("exp_avg", "exp_avg_sq")
+@warmup_guard(_adopt_warmup_1, _adopt_warmup_2)
@no_state
def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
- if group["step"] == 1:
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
- raise SkipUpdate from None
-
- if group["step"] == 2:
- update = utils.promote(update)
- easq = utils.promote(exp_avg_sq)
- [
- utils.copy_stochastic_(ea, u / easq_.sqrt().clamp_(min=group["eps"]))
- for ea, u, easq_ in zip(exp_avg, update, easq)
- ]
- utils.scale_by_exp_avg_sq_(
- exp_avg_sq,
- update,
- utils.beta_debias(utils.get_beta2(group), group["step"]),
- group["eps"],
- )
- raise SkipUpdate from None
-
return utils.adopt(
update,
exp_avg_sq,
@@ -769,6 +881,7 @@ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ tmp = utils.get_temporary(group, param) or {}
Q = utils.init_Q_exprs(
grad,
group["precond_init_scale"],
@@ -777,22 +890,66 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
group["max_size_triangular"],
group["min_ndim_triangular"],
group["memory_save_mode"],
- getattr(param, "hessian_vector", None),
- getattr(param, "vector", None),
+ tmp.get("hessian_vector"),
+ tmp.get("vector"),
dtype=getattr(torch, group["q_dtype"]),
)
state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
- state["step"] = torch.zeros((), device=param.device, dtype=torch.float64) # torch casts int to float in ckpt load
- if group["adaptive"]:
- state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
+ state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
if not cached:
return
state["Q_cache"] = [torch.empty_like(q) for q in Q]
+def _init_psgd_eigen_kron(state, group, update, grad, param, prob: Optional[callable] = None):
+ tmp = utils.get_temporary(group, param) or {}
+ Q = utils.init_Q_exprs(
+ grad,
+ group["precond_init_scale"],
+ group["precond_init_scale_scale"],
+ group["precond_init_scale_power"],
+ group["max_size_triangular"],
+ group["min_ndim_triangular"],
+ group["memory_save_mode"],
+ tmp.get("hessian_vector"),
+ tmp.get("vector"),
+ dtype=getattr(torch, group["q_dtype"]),
+ )
+ state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+ state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
+
+ line, group["store_triu_as_line"] = group["store_triu_as_line"], False
+ _update_psgd_precond(False, None, group, param, update, Q, state["running_lower_bound"], state["step"], prob)
+ group["store_triu_as_line"] = line
+ state["Q"] = utils.triu_to_line(Q) if line else Q
+ state["Q_basis"] = utils.init_psgd_eigenbasis(Q)
+
+
+def _init_psgd_pro_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ Q = utils.init_Q_exprs(
+ grad,
+ group["precond_init_scale"],
+ group["precond_init_scale_scale"],
+ group["precond_init_scale_power"],
+ group["max_size_triangular"],
+ group["min_ndim_triangular"],
+ group["memory_save_mode"],
+ None,
+ None,
+ dtype=getattr(torch, group["q_dtype"]),
+ )
+ state["Q"] = Q
+ state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+ state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
+ if not cached:
+ return
+ state["Q_cache"] = [torch.empty_like(q) for q in Q]
+
+
def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ tmp = utils.get_temporary(group, param) or {}
state["U"], state["V"], state["d"] = utils.init_lra(
grad,
group["param_count"],
@@ -800,30 +957,14 @@ def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob
group["precond_init_scale_scale"],
group["precond_init_scale_power"],
group["rank"],
- getattr(param, "hessian_vector", None),
- getattr(param, "vector", None),
+ tmp.get("hessian_vector"),
+ tmp.get("vector"),
dtype=getattr(torch, group["q_dtype"]),
)
-def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
- step = group["step"]
- if "precondition_frequency" in group:
- return step > 0 and step % group["precondition_frequency"] == 0
- if isinstance(step, torch.Tensor):
- utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
- rng = random.Random(0x172381)
- else:
- rng = random.Random(0x172381 ^ step)
- if "precond_scheduler" in group:
- return utils.precond_schedule(step, group["precond_scheduler"], rng)
- if prob is not None:
- return utils.psgd_should_update(group, prob, rng, name=name)
- raise ValueError("No preconditioner update schedule specified.")
-
-
@needs_full_param
-@no_state_no_foreach
+@no_state_no_multi_tensor
def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
if update.dim() < 2:
return update
@@ -846,8 +987,20 @@ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grok
return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+def _store_init_norm(state, group, update, grad, param):
+ state["init_norm"] = param.to(_storage_dtype(group)).norm()
+
+
+@needs_full_param
+@general_guard("init_norm", init_fn=_store_init_norm, skip_first=False)
+@no_state
+def update_by_hyperball(group, update, grad, param, init_norm):
+ utils.hyperball_step_(param, update, init_norm, group["lr"], group["weight_decay"], group["caution"], grad)
+ raise SkipUpdate from None
+
+
def _store_std(state, group, update, grad, param):
- state["init_std"] = torch.std(param)
+ state["init_std"] = torch.std(param.to(_storage_dtype(group)))
@needs_full_param
@@ -1032,34 +1185,27 @@ def _update_psgd_precond(
param,
grad,
Q,
- velocity,
running_lower_bound,
step,
prob: Optional[callable] = None,
-) -> Optional[Tensor]:
+) -> None:
if prob is None:
prob = utils.precond_update_prob_schedule()
if not group["is_preconditioning"]:
return
- if utils.hasattr_none(param, "vector"):
- vector, hessian_vector = param.vector, param.hessian_vector
- del param.vector
- del param.hessian_vector
- elif group["inverse_free"]:
- vector, hessian_vector = None, grad
- else:
+ if (utils.get_temporary(group, param) or {}).get("vector") is None:
vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
+ else:
+ vector, hessian_vector = utils.take_temporary(group, param, "vector", "hessian_vector")
- precond = utils.psgd_update_precond(
+ utils.psgd_update_precond(
hessian_vector,
group["precond_lr"],
Q,
group["store_triu_as_line"],
- velocity,
utils.get_beta2(group),
- group["ortho_method"],
vector,
running_lower_bound,
group["lower_bound_beta"],
@@ -1073,12 +1219,10 @@ def _update_psgd_precond(
float_prob = prob(group["step"])
group["is_cached"] = should_use_cache = cached and float_prob < 0.5
- if precond is not None:
- return precond
if not should_use_cache or not cached:
- return None # caching adds extra ops and is not worth the overhead when we precondition at every step
+ return
- Q_resolved = utils.line_to_triu(Q, group["inverse_free"]) if group["store_triu_as_line"] else Q
+ Q_resolved = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
for i, (c_, q_) in enumerate(zip(Q_cache, Q_resolved)):
if c_ is None:
c_ = (
@@ -1091,7 +1235,57 @@ def _update_psgd_precond(
torch.matmul(q_.T, q_, out=c_)
else:
torch.mul(q_, q_, out=c_)
- return None
+ return
+
+
+def _update_psgd_pro_precond(
+ cached,
+ Q_cache,
+ group,
+ param,
+ grad,
+ Q,
+ running_lower_bound,
+ step,
+ prob: Optional[callable] = None,
+) -> None:
+ if prob is None:
+ prob = utils.precond_update_prob_schedule()
+
+ if not group["is_preconditioning"]:
+ return
+
+ utils.psgd_pro_update_precond(
+ grad,
+ group["precond_lr"],
+ Q,
+ running_lower_bound,
+ group["lower_bound_beta"],
+ group["precond_update_power_iterations"],
+ group["dampening"],
+ )
+
+ if isinstance(prob, float):
+ float_prob = prob
+ else:
+ float_prob = prob(group["step"])
+ group["is_cached"] = should_use_cache = cached and float_prob < 0.5
+
+ if not should_use_cache or not cached:
+ return
+
+ for i, (c_, q_) in enumerate(zip(Q_cache, Q)):
+ if c_ is None:
+ c_ = (
+ torch.empty_like(q_)
+ if q_.ndim == 1
+ else torch.empty(q_.shape[0], q_.shape[0], device=q_.device, dtype=q_.dtype)
+ )
+ Q_cache[i] = c_
+ if q_.ndim == 2:
+ torch.matmul(q_.T, q_, out=c_)
+ else:
+ torch.mul(q_, q_, out=c_)
def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
@@ -1099,9 +1293,7 @@ def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
if group.get("is_cached", False) and Q_cache[0] is not None:
out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
else:
- out = utils.psgd_precond_grad(
- preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
- )
+ out = utils.psgd_precond_grad(preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs)
group["caution"] = False # we already cautioned here - shouldn't do it again
return out
@@ -1118,9 +1310,7 @@ def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
if group.get("is_cached", False) and Q_cache[0] is not None:
utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
else:
- utils.fused_psgd_precond_grad(
- preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
- )
+ utils.fused_psgd_precond_grad(preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs)
def _update_lra(
@@ -1129,12 +1319,10 @@ def _update_lra(
if not group["is_preconditioning"]:
return utils.multi_flatten((U, 1), (V, 1), (d, 0))
- if utils.hasattr_none(params[0], "hessian_vector"):
- vector = utils.flatten([p.vector for p in params])
- hessian_vector = utils.flatten([p.hessian_vector for p in params])
- for p in params:
- del p.vector
- del p.hessian_vector
+ if (utils.get_temporary(group, params[0]) or {}).get("hessian_vector") is not None:
+ vector_hv = [utils.take_temporary(group, p, "vector", "hessian_vector") for p in params]
+ vector = utils.flatten([v for v, _ in vector_hv])
+ hessian_vector = utils.flatten([hv for _, hv in vector_hv])
else:
vector, hessian_vector = utils.dampen_multiple(grads)
precond_step = group["precond_step"] = group.get("precond_step", -1) + 1
@@ -1188,8 +1376,8 @@ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
def scale_by_psgd(
group,
update,
@@ -1198,21 +1386,63 @@ def scale_by_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
- _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@zero_guard("exp_avg", "exp_avg_sq")
+@general_guard("Q", "Q_basis", "running_lower_bound", "step", init_fn=_init_psgd_eigen_kron, skip_first=True)
+@no_state_no_multi_tensor
+def scale_by_lather(
+ group,
+ update,
+ grad,
+ param,
+ update_to_precond,
+ exp_avg,
+ exp_avg_sq,
+ Q,
+ Q_basis,
+ running_lower_bound: List[Tensor],
+ step: Tensor,
+ prob: Optional[callable] = None,
+):
+ projected = utils.project(utils.promote(update), Q_basis, False)
+ precond = utils.adam_(
+ exp_avg,
+ exp_avg_sq,
+ projected,
+ utils.get_beta1(group),
+ utils.get_beta2(group),
+ group["step"],
+ group["eps"],
+ )[0]
+ precond = utils.project(precond, Q_basis, True)
+
+ if group["is_preconditioning"]:
+ _update_psgd_precond(False, None, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+ utils.update_psgd_eigenbasis(
+ utils.line_to_triu(Q) if group["store_triu_as_line"] else Q,
+ Q_basis,
+ exp_avg,
+ )
+
+ return precond
+
+
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
def scale_by_delayed_psgd(
group,
update,
@@ -1221,27 +1451,21 @@ def scale_by_delayed_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
- if group.get("inverse_free", False):
- precond = None
- else:
- precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
- new = _update_psgd_precond(
- cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob
- )
- return new if precond is None else precond
+ precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+ return precond
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
def update_by_psgd(
group,
update,
@@ -1250,13 +1474,12 @@ def update_by_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
- _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
_fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
raise SkipUpdate from None
@@ -1276,8 +1499,8 @@ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@no_state_no_multi_tensor
def update_by_delayed_psgd(
group,
update,
@@ -1286,14 +1509,58 @@ def update_by_delayed_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
_fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
- _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+ raise SkipUpdate from None
+
+
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
+@no_state_no_multi_tensor
+def scale_by_psgd_pro(
+ group,
+ update,
+ grad,
+ param,
+ update_to_precond,
+ Q,
+ Q_cache,
+ running_lower_bound: List[Tensor],
+ step: Tensor,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+):
+ _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+ return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+
+
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
+@no_state_no_multi_tensor
+def update_by_psgd_pro(
+ group,
+ update,
+ grad,
+ param,
+ update_to_precond,
+ Q,
+ Q_cache,
+ running_lower_bound: List[Tensor],
+ step: Tensor,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+):
+ _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+ _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
raise SkipUpdate from None
@@ -1304,6 +1571,12 @@ def palm_beta2(state, group, update, grad, param):
def apply_to_idx(fn, idx):
+ name = fn
+ if isinstance(fn, str):
+ fn = getattr(utils, fn, None)
+ if fn is None or not callable(fn):
+ raise ValueError(f"Unknown function '{name}'")
+
def _fn(state, group, update, grad, param):
args = [state, group, update, grad, param]
return fn(args[idx])
@@ -1312,15 +1585,70 @@ def _fn(state, group, update, grad, param):
return _fn
+_FSDP_HEADER_WIDTH = 4
+_FSDP_BUCKET_BYTES = 32 << 20
+_FSDP_DTYPE_CODES = {
+ torch.float64: 0,
+ torch.float32: 1,
+ torch.float16: 2,
+ torch.bfloat16: 3,
+ torch.int64: 4,
+ torch.int32: 5,
+ torch.int16: 6,
+ torch.int8: 7,
+ torch.uint8: 8,
+ torch.bool: 9,
+}
+
+
class _ShapeInfo:
- __slots__ = ("orig_shape", "offset", "total", "group", "owner")
+ __slots__ = ("orig_shape", "offset", "total", "group", "owner", "param_idx")
- def __init__(self, orig_shape, offset=0, total=None, group=None, owner=None):
+ def __init__(self, orig_shape, offset=0, total=None, group=None, owner=None, param_idx=None):
self.orig_shape = orig_shape
self.offset = offset
self.total = total if total is not None else math.prod(orig_shape)
self.group = group
self.owner = owner
+ self.param_idx = param_idx
+
+
+class _FSDPBucket:
+ __slots__ = ("device", "dtype", "send_entries", "send_splits", "recv_entries", "recv_splits")
+
+ def __init__(self, device, dtype, send_entries, send_splits, recv_entries, recv_splits):
+ self.device = device
+ self.dtype = dtype
+ self.send_entries = send_entries
+ self.send_splits = send_splits
+ self.recv_entries = recv_entries
+ self.recv_splits = recv_splits
+
+
+class _FSDPState:
+ __slots__ = ("items", "buckets")
+
+ def __init__(self, items, buckets):
+ self.items = items
+ self.buckets = buckets
+
+
+def _dtype_code(dtype):
+ if dtype not in _FSDP_DTYPE_CODES:
+ raise TypeError(f"Unsupported FSDP shard dtype: {dtype}")
+ return _FSDP_DTYPE_CODES[dtype]
+
+
+def _assign_fsdp_owners(entries, shard_sizes, world_size):
+ loads = [0] * world_size
+ owners = []
+ for i, (p, _, total, _) in enumerate(entries):
+ active = shard_sizes[i].nonzero().squeeze(-1).tolist()
+ candidates = active or list(range(world_size))
+ owner = min(candidates, key=loads.__getitem__)
+ loads[owner] += total * p.element_size()
+ owners.append(owner)
+ return owners
def _detect_orig_shapes(params):
@@ -1343,8 +1671,6 @@ def _detect_orig_shapes(params):
for param, spi, shape in zip(obj._params, obj._shard_param_infos, obj._shapes):
lookup[id(param)] = (tuple(shape), spi)
- _dist = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
-
# optimizer param order is stable across ranks
fsdp_entries = [
(p, s, math.prod(s), spi)
@@ -1352,47 +1678,184 @@ def _detect_orig_shapes(params):
for s, spi in [lookup.get(id(p), (None, None))]
if id(p) in fsdp_ids and s is not None
]
-
- groups = {}
- if _dist and fsdp_entries:
- rank = torch.distributed.get_rank()
- ws = torch.distributed.get_world_size()
- n = len(fsdp_entries)
- flags = torch.zeros(n, ws, dtype=torch.int32, device=fsdp_entries[0][0].device)
- for i, (p, orig, total, spi) in enumerate(fsdp_entries):
- if spi.in_shard and spi.numel_in_shard is not None and spi.numel_in_shard < total:
- flags[i, rank] = 1
- torch.distributed.all_reduce(flags)
- pg_cache = {}
- for i in range(n):
- ranks = flags[i].nonzero().squeeze(-1).tolist()
- if len(ranks) >= 2:
- key = tuple(ranks)
- if key not in pg_cache:
- pg_cache[key] = torch.distributed.new_group(ranks)
- groups[i] = (pg_cache[key], ranks)
-
- # owner must be precomputed (different ranks see different subsets in _reshape_params)
result = {}
- split_idx = 0
- for i, (p, orig, total, spi) in enumerate(fsdp_entries):
- sg = groups.get(i)
- if sg:
- pg, ranks = sg
- owner = ranks[split_idx % len(ranks)]
- split_idx += 1
- if spi.in_shard:
- result[id(p)] = _ShapeInfo(orig, spi.intra_param_start_idx, total, pg, owner)
- elif spi.in_shard:
- result[id(p)] = _ShapeInfo(orig, spi.intra_param_start_idx, total)
+ ws = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+ if ws > 1 and fsdp_entries:
+ rank = torch.distributed.get_rank()
+ shard_sizes = torch.zeros(len(fsdp_entries), ws, dtype=torch.int64, device=fsdp_entries[0][0].device)
+ for i, (p, _, _, spi) in enumerate(fsdp_entries):
+ shard_sizes[i, rank] = p.numel() if spi.in_shard else 0
+ torch.distributed.all_reduce(shard_sizes)
+ owners = _assign_fsdp_owners(fsdp_entries, shard_sizes, ws)
+ else:
+ owners = [None] * len(fsdp_entries)
+ for param_idx, ((p, orig, total, spi), owner) in enumerate(zip(fsdp_entries, owners)):
+ offset = 0 if spi.intra_param_start_idx is None else spi.intra_param_start_idx
+ result[id(p)] = _ShapeInfo(orig, offset, total, owner=owner, param_idx=param_idx)
return result
-def _reduce_gather(shard, offset, total, pg, dst):
- full = shard.new_zeros(total)
- full[offset : offset + shard.numel()] = shard
- torch.distributed.reduce(full, dst=dst, group=pg)
- return full
+def _exchange_split_sizes(splits, device):
+ send = torch.tensor(splits, dtype=torch.int64, device=device)
+ recv = torch.empty_like(send)
+ torch.distributed.all_to_all_single(recv, send)
+ return recv.tolist()
+
+
+def _all_to_all_variable(sendbuf, recv_splits, send_splits):
+ recv = sendbuf.new_empty(sum(recv_splits))
+ torch.distributed.all_to_all_single(recv, sendbuf, output_split_sizes=recv_splits, input_split_sizes=send_splits)
+ return recv
+
+
+def _fsdp_bucket_schedule(items):
+ buckets, current, lookup = [], {}, {}
+ for p, info, _ in items:
+ key = (p.device, p.dtype)
+ idx = current.get(key)
+ size = info.total * p.element_size()
+ if idx is None or (buckets[idx][2] and buckets[idx][2] + size > _FSDP_BUCKET_BYTES):
+ idx = len(buckets)
+ buckets.append([p.device, p.dtype, 0])
+ current[key] = idx
+ buckets[idx][2] += size
+ lookup[info.param_idx] = idx
+ return [(device, dtype) for device, dtype, _ in buckets], lookup
+
+
+def _exchange_fsdp_shards(schedule, bucket_lookup, items, tensor_getter, keep_state=False):
+ ws = torch.distributed.get_world_size()
+ per_bucket = [[] for _ in schedule]
+ for p, info, shard in items:
+ tensor = tensor_getter(p, shard)
+ if tensor is None or tensor.numel() == 0:
+ continue
+ flat = tensor.reshape(-1)
+ bucket_idx = bucket_lookup[info.param_idx]
+ device, dtype = schedule[bucket_idx]
+ if flat.device != device or flat.dtype != dtype:
+ raise RuntimeError(
+ f"FSDP bucket mismatch for param {info.param_idx}: expected {(device, dtype)}, got {(flat.device, flat.dtype)}"
+ )
+ per_bucket[bucket_idx].append((info.owner, info.param_idx, info.offset, flat, shard))
+
+ received, states = {}, []
+ for (device, dtype), bucket_entries in zip(schedule, per_bucket):
+ by_dst = [[] for _ in range(ws)]
+ for entry in bucket_entries:
+ by_dst[entry[0]].append(entry)
+
+ send_meta_splits = [len(dst_entries) * _FSDP_HEADER_WIDTH for dst_entries in by_dst]
+ send_payload_splits = [sum(flat.numel() for _, _, _, flat, _ in dst_entries) for dst_entries in by_dst]
+ recv_meta_splits = _exchange_split_sizes(send_meta_splits, device)
+ recv_payload_splits = _exchange_split_sizes(send_payload_splits, device)
+
+ code = _dtype_code(dtype)
+ meta = [
+ value
+ for dst_entries in by_dst
+ for _, param_idx, offset, flat, _ in dst_entries
+ for value in (param_idx, offset, flat.numel(), code)
+ ]
+ payload = [flat for dst_entries in by_dst for _, _, _, flat, _ in dst_entries]
+ send_meta = (
+ torch.tensor(meta, dtype=torch.int64, device=device)
+ if meta
+ else torch.empty(0, dtype=torch.int64, device=device)
+ )
+ send_payload = torch.cat(payload) if payload else torch.empty(0, dtype=dtype, device=device)
+
+ recv_meta = _all_to_all_variable(send_meta, recv_meta_splits, send_meta_splits)
+ recv_entries = [[] for _ in range(ws)]
+ meta_offset = 0
+ for src, count in enumerate(recv_meta_splits):
+ if count == 0:
+ continue
+ if count % _FSDP_HEADER_WIDTH:
+ raise RuntimeError(f"Malformed FSDP metadata split: {count}")
+ rows = recv_meta[meta_offset : meta_offset + count].view(-1, _FSDP_HEADER_WIDTH).cpu().tolist()
+ meta_offset += count
+ for param_idx, offset, length, got in rows:
+ if got != code:
+ raise RuntimeError(f"FSDP dtype mismatch for bucket {dtype}: expected {code}, got {got}")
+ recv_entries[src].append((param_idx, offset, length))
+
+ recv_payload = _all_to_all_variable(send_payload, recv_payload_splits, send_payload_splits)
+ payload_offset = 0
+ for src_entries in recv_entries:
+ for param_idx, offset, length in src_entries:
+ chunk = recv_payload[payload_offset : payload_offset + length]
+ received.setdefault(param_idx, []).append((offset, chunk))
+ payload_offset += length
+ if payload_offset != recv_payload.numel():
+ raise RuntimeError("FSDP payload unpack mismatch")
+
+ if keep_state:
+ states.append(_FSDPBucket(device, dtype, by_dst, send_payload_splits, recv_entries, recv_payload_splits))
+
+ return received, states
+
+
+def _reshape_fsdp_params(items):
+ rank = torch.distributed.get_rank()
+ schedule, bucket_lookup = _fsdp_bucket_schedule(items)
+ params, buckets = _exchange_fsdp_shards(schedule, bucket_lookup, items, lambda _, shard: shard, keep_state=True)
+ grads, _ = _exchange_fsdp_shards(schedule, bucket_lookup, items, lambda p, _: p.grad)
+
+ for p, info, shard in items:
+ p.grad = None
+ if info.owner != rank:
+ continue
+
+ pieces = params.get(info.param_idx, ())
+ total = sum(chunk.numel() for _, chunk in pieces)
+ if total != info.total:
+ raise RuntimeError(f"FSDP parameter assembly mismatch for param {info.param_idx}: {total} != {info.total}")
+
+ full = shard.new_empty(info.total)
+ for offset, chunk in pieces:
+ full[offset : offset + chunk.numel()].copy_(chunk)
+ p.data = full.view(info.orig_shape)
+
+ grad_pieces = grads.get(info.param_idx, ())
+ if not grad_pieces:
+ continue
+ grad_total = sum(chunk.numel() for _, chunk in grad_pieces)
+ if grad_total != info.total:
+ raise RuntimeError(f"FSDP grad assembly mismatch for param {info.param_idx}: {grad_total} != {info.total}")
+ grad = full.new_empty(info.total, dtype=grad_pieces[0][1].dtype)
+ for offset, chunk in grad_pieces:
+ grad[offset : offset + chunk.numel()].copy_(chunk)
+ p.grad = grad.view(info.orig_shape)
+
+ return _FSDPState(items, buckets)
+
+
+def _restore_fsdp_params(state):
+ by_param = {info.param_idx: (p, info, shard) for p, info, shard in state.items}
+ for bucket in state.buckets:
+ payload = []
+ for dst, recv_entries in enumerate(bucket.recv_entries):
+ for param_idx, offset, length in recv_entries:
+ p, info, _ = by_param[param_idx]
+ flat = p.data.reshape(-1)
+ if flat.numel() != info.total:
+ raise RuntimeError(f"FSDP return path expects full param {param_idx}, got {flat.numel()}")
+ payload.append(flat[offset : offset + length])
+ send_payload = torch.cat(payload) if payload else torch.empty(0, dtype=bucket.dtype, device=bucket.device)
+ recv_payload = _all_to_all_variable(send_payload, bucket.send_splits, bucket.recv_splits)
+
+ payload_offset = 0
+ for send_entries in bucket.send_entries:
+ for _, _, _, flat, shard in send_entries:
+ shard.copy_(recv_payload[payload_offset : payload_offset + flat.numel()])
+ payload_offset += flat.numel()
+ if payload_offset != recv_payload.numel():
+ raise RuntimeError("FSDP return payload unpack mismatch")
+
+ for p, _, shard in state.items:
+ p.data = shard
+ p.grad = None
def _view_param(p, shape):
@@ -1404,57 +1867,55 @@ def _view_param(p, shape):
def _reshape_params(params, orig_shapes, gather=True):
if not orig_shapes:
return [], []
- rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ dist_ready = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
views, gathers = [], []
for p in params:
info = orig_shapes.get(id(p))
- if info is None or p.data.shape == info.orig_shape:
+ if info is None:
continue
- if info.group is not None and gather:
+ if gather and dist_ready and info.owner is not None:
shard = p.data
- full = _reduce_gather(shard, info.offset, info.total, info.group, info.owner)
- if rank == info.owner:
- p.data = full.view(info.orig_shape)
- if p.grad is not None:
- p.grad = _reduce_gather(p.grad, info.offset, info.total, info.group, info.owner).view(
- info.orig_shape
- )
- else:
- del full
- if p.grad is not None:
- _reduce_gather(p.grad, info.offset, info.total, info.group, info.owner)
- p.grad = None
gathers.append((p, info, shard))
+ continue
+
+ if p.data.shape == info.orig_shape:
+ continue
+
+ orig, numel = info.orig_shape, p.data.numel()
+ if numel == info.total:
+ target = orig
+ elif numel > 0 and len(orig) >= 2:
+ inner = math.prod(orig[1:])
+ target = (numel // inner, *orig[1:]) if numel % inner == 0 else None
else:
- orig, numel = info.orig_shape, p.data.numel()
- if numel == info.total:
- target = orig
- elif numel > 0 and len(orig) >= 2:
- inner = math.prod(orig[1:])
- target = (numel // inner, *orig[1:]) if numel % inner == 0 else None
- else:
- continue
- if target is not None:
- flat = p.data.shape
- _view_param(p, target)
- views.append((p, flat))
+ continue
+ if target is not None:
+ flat = p.data.shape
+ _view_param(p, target)
+ views.append((p, flat))
+
+ if gathers:
+ gathers = _reshape_fsdp_params(gathers)
return views, gathers
def _restore_params(views, gathers):
- rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
- for p, info, shard in gathers:
- if rank == info.owner:
- full = p.data.flatten()
- else:
- full = shard.new_empty(info.total)
- torch.distributed.broadcast(full, src=info.owner, group=info.group)
- shard.copy_(full[info.offset : info.offset + shard.numel()])
- p.data = shard
- p.grad = None
+ if isinstance(gathers, _FSDPState):
+ _restore_fsdp_params(gathers)
+ else:
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ for p, info, shard in gathers:
+ if rank == info.owner:
+ full = p.data.flatten()
+ else:
+ full = shard.new_empty(info.total)
+ torch.distributed.broadcast(full, src=info.owner, group=info.group)
+ shard.copy_(full[info.offset : info.offset + shard.numel()])
+ p.data = shard
+ p.grad = None
for p, flat in views:
_view_param(p, flat)
@@ -1472,7 +1933,7 @@ def _inner_chain(state, group, update, grad, param, *fns):
return update, skip_update
-def chain(state: Union[callable, dict], group, grad, param, *fns):
+def chain(state: list, group, grad, param, *fns):
update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
ecc = ECCConfig.from_group(group, key="param_ecc")
@@ -1483,8 +1944,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
return
- states = [state(p) for p in param]
- corrs = [st["param::ecc"] for st in states]
+ corrs = [st["param::ecc"] for st in state]
with ecc.attached(param, corrs):
update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
if not skip_update and update is not None:
@@ -1500,9 +1960,14 @@ def _walk_fns(obj):
stack.append(cur.fn)
elif isinstance(cur, functools.partial):
stack.append(cur.func)
- elif isinstance(cur, Branch):
+ elif isinstance(cur, Parallel):
for branch in cur.branches:
stack.extend(branch)
+ elif isinstance(cur, Route):
+ for _, fns in cur.routes:
+ stack.extend(fns)
+ if cur.default is not None:
+ stack.extend(cur.default)
elif isinstance(cur, _Iterable) and not isinstance(cur, (str, bytes, bytearray)):
stack.extend(cur)
@@ -1518,6 +1983,7 @@ def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
for ft in _walk_fns(new_fns):
if not retain or ft.transform_idx is None:
ft.transform_idx, offset = offset, offset + 1
+ ft._build_val_names()
return new_fns
@@ -1532,15 +1998,18 @@ class ChainOpt(utils.StatefulOptimizer):
"eps": 1e-8,
}
- def __init__(self, params, defaults, foreach: bool, *fns):
+ def __init__(self, params, defaults, *fns):
orig = defaults.pop("orig_shapes", None)
self._orig_shapes = (
{k: _ShapeInfo(v) if isinstance(v, tuple) else v for k, v in orig.items()} if orig is not None else None
)
base = self.global_defaults.copy()
base.update({k: v for k, v in defaults.items() if v is not use_default})
- super().__init__(params, base, foreach)
+ super().__init__(params, base)
self.fns = fns
+ self._eager_chain = self._run_chain
+ if self.compile_step:
+ self._run_chain = torch.compile(self._run_chain, fullgraph=True)
self.register_load_state_dict_post_hook(ChainOpt._restore_ecc_dtypes)
self._init_param_ecc()
@@ -1631,17 +2100,28 @@ def fns(self, value):
self._fns = value
self._set_indices(retain=True)
self._needs_gather = any(getattr(ft, "needs_full_param", False) for ft in _walk_fns(self._fns))
+ self._transform_ids = frozenset(
+ ft.transform_idx
+ for ft in _walk_fns(self._fns)
+ if ft.transform_idx is not None and getattr(ft, "needs_init", True)
+ )
def _set_indices(self, retain=True):
self._fns = set_indices(self.fns, retain)
+ def _find_val_name(self, name):
+ for ft in _walk_fns(self._fns):
+ if name in ft._val_names:
+ return ft._val_names[name]
+ raise KeyError(f"No transform stores '{name}'")
+
def _step(self, group):
if "base_lr" not in group:
group["base_lr"] = group["lr"]
- if "prev_lr" in group and group["prev_lr"] != group["lr"]:
+ if "base_lr" in group and group["base_lr"] != group["lr"]:
utils.warn_once(
f"Learning rate changed between steps. This is an experimental feature and "
- f"only supported with foreach=True (currently foreach={group['foreach']})."
+ f"only supported with multi_tensor=True (currently multi_tensor={group['multi_tensor']})."
)
group["base_lr"] = group["lr"]
@@ -1664,48 +2144,69 @@ def _step_inner(self, group):
return
p, g = zip(*vals)
- for param in p:
- state = self.state_(param)
- if "step" in state:
- step = state["step"]
- elif self.compile_step:
- step = utils.scalar_guard(0, param)
+ step = group.get("_group_step")
+ if step is None:
+ for param in group["params"]:
+ param_state = self.state.get(param)
+ if not isinstance(param_state, dict):
+ continue
+ for idx_state in param_state.values():
+ if isinstance(idx_state, dict) and "step" in idx_state:
+ step = idx_state["step"]
+ break
+ if step is not None:
+ break
else:
step = 0
- break
-
- group["step"] = state["step"] = step = step + 1
- group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
+ if isinstance(step, torch.Tensor):
+ step = step.to(device=p[0].device, dtype=torch.int64)
+ else:
+ step = utils.scalar_guard(step, p[0])
+ group["_group_step"] = group["step"] = step = step + 1
+ self.state_(p[0])["step"] = step
+ group["prev_lr"] = group["lr"] = group["base_lr"] * step / step.clamp(min=group["warmup_steps"] + 1)
- if not group["foreach"] or len(p) == 1:
+ if not group["multi_tensor"] or len(p) == 1:
for param, grad in zip(p, g):
- chain(self.state_, group, [grad], [param], *self.fns)
- group["caution"] = caution
+ self._chain(group, [grad], [param], caution)
else:
- chain(self.state_, group, g, p, *self.fns)
+ self._chain(group, g, p, caution)
group["caution"] = caution
- group["lr"] = group["prev_lr"]
+ group["lr"] = group["base_lr"]
group["step"] = None
+ def _run_chain(self, state, group, g, p, caution):
+ chain(state, group, g, p, *self.fns)
+ group["caution"] = caution
-str_or_fn = Union[str, callable, None, Literal[use_default]]
+ def _needs_init(self, state):
+ ids = self._transform_ids
+ if not ids:
+ return False
+ all_initialized = set()
+ for st in state:
+ all_initialized.update(st.get("is_initialized", ()))
+ return not ids.issubset(all_initialized)
+
+ def _needs_eager(self, group, state):
+ if self._needs_init(state):
+ return True
+ if group.get("is_preconditioning", False):
+ return True
+ if group.get("ecc") or group.get("param_ecc"):
+ return True
+ return False
+
+ def _chain(self, group, g, p, caution):
+ state = [self.state_(pi) for pi in p]
+ fn = self._run_chain
+ if self.compile_step and self._needs_eager(group, state):
+ fn = self._eager_chain
+ fn(state, group, g, p, caution)
-def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
- name = default(name, default_val)
- if callable(name):
- return name
- elif name not in (
- "l2_clip_",
- "rmsnorm_clip_",
- "trust_region_clip_",
- "a_law_compress",
- "mu_law_compress",
- "softsign_compress",
- ):
- raise ValueError(f"Clipping function {name} not found")
- return getattr(utils, name)
+str_or_fn = Union[str, callable, None, Literal[use_default]]
def default(a, b):
@@ -1723,6 +2224,7 @@ def default(a, b):
scale_by_laprop.get_fn(): update_by_laprop, #
scale_by_adopt.get_fn(): update_by_adopt, #
scale_by_ademamix.get_fn(): update_by_ademamix, #
+ scale_by_psgd_pro.get_fn(): update_by_psgd_pro, #
}
_scale_to_update_map_inv = {
update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
@@ -1734,6 +2236,7 @@ def default(a, b):
update_by_laprop.get_fn(): scale_by_laprop, #
update_by_adopt.get_fn(): scale_by_adopt, #
update_by_ademamix.get_fn(): scale_by_ademamix, #
+ update_by_psgd_pro.get_fn(): scale_by_psgd_pro, #
}
@@ -1742,23 +2245,17 @@ class BaseOpt(ChainOpt):
Base Optimizer
compile_step: bool = False
- Whether to change some internals to try to make the optimizer compilable
- This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP)
+ Whether to torch.compile the optimizer step (fullgraph=True).
+ Initialization runs eagerly on the first step; subsequent steps are compiled.
promote: bool = False
- Whether to promote the gradients to fp32 before applying the optimizer
- Improves update quality for low-precision parameters, but increases costs
- Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime
+ Whether to promote the gradients to fp32 before applying the optimizer.
gradient_clipping: str_or_fn = None
- The function to use for clipping the incoming gradients, before any other transformations.
- This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+ Clipping function applied to incoming gradients before any other transforms.
update_clipping: str_or_fn = None
- The function to use for clipping the outgoing updates before applying them, after all other transformations.
- This will turn off fused updates.
- This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
-
+ Clipping function applied to outgoing updates. Disables fused updates.
"""
gradient_clipping: str_or_fn = None
@@ -1770,7 +2267,6 @@ def __init__(
self,
params,
defaults,
- foreach: bool = True,
gradient_clipping: str_or_fn = None,
update_clipping: str_or_fn = None,
palm: bool = use_default,
@@ -1818,57 +2314,49 @@ def __init__(
if default(update_clipping, self.update_clipping) is not None:
fns = fns + (apply_to_idx(update_clipping, 2),)
- super().__init__(params, defaults, foreach, *fns)
+ super().__init__(params, defaults, *fns)
class ScheduleFree(BaseOpt):
def eval(self):
+ return self.train(False)
+
+ def train(self, mode: bool = True):
+ z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
- beta1 = utils.get_beta1(group)
- if beta1 > 0 and not train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if "z" in state:
- # Set p.data to x
- z = utils.promote(state["z"])
- p32 = utils.promote(p.data)
- p32.lerp_(end=z, weight=1 - 1 / beta1)
- utils.copy_stochastic_(p.data, p32)
-
- def train(self):
- for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
+ train_mode = group.get("train_mode", True)
+ if train_mode == mode:
+ continue
+ group["train_mode"] = mode
beta1 = utils.get_beta1(group)
- if beta1 > 0 and train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if "z" in state:
- z = utils.promote(state["z"])
- p32 = utils.promote(p.data)
- p32.lerp_(end=z, weight=1 - beta1)
- utils.copy_stochastic_(p.data, p32)
+ if beta1 <= 0:
+ continue
+ weight = 1 - beta1 if mode else 1 - 1 / beta1
+ for p in group["params"]:
+ state = self.state_(p)
+ if z_key in state:
+ z = utils.promote(state[z_key])
+ p32 = utils.promote(p.data)
+ p32.lerp_(end=z, weight=weight)
+ utils.copy_stochastic_(p.data, p32)
+ return self
class MSAM(BaseOpt):
def eval(self):
+ return self.train(False)
+
+ def train(self, mode: bool = True):
+ z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
- if not train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if "z" in state:
- p_copy = p.data.clone()
- utils.copy_stochastic_(p.data, state["z"])
- utils.copy_stochastic_(state["z"], p_copy)
-
- def train(self):
- for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
- if train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if "z" in state:
- p_copy = p.data.clone()
- utils.copy_stochastic_(p.data, state["z"])
- utils.copy_stochastic_(state["z"], p_copy)
+ train_mode = group.get("train_mode", True)
+ if train_mode == mode:
+ continue
+ group["train_mode"] = mode
+ for p in group["params"]:
+ state = self.state_(p)
+ if z_key in state:
+ p_copy = p.data.clone()
+ utils.copy_stochastic_(p.data, state[z_key])
+ utils.copy_stochastic_(state[z_key], p_copy)
+ return self
diff --git a/heavyball/helpers.py b/heavyball/helpers.py
index 23a4e77..71a2cdb 100644
--- a/heavyball/helpers.py
+++ b/heavyball/helpers.py
@@ -178,7 +178,6 @@ class BoTorchSampler(SimpleAPIBaseSampler):
"""
A significantly more efficient implementation of `BoTorchSampler` from Optuna - keeps more on the GPU / in torch
The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
- The original API is kept for backward compatibility, but many arguments are ignored to improve maintainability.
"""
def __init__(
@@ -186,18 +185,12 @@ def __init__(
search_space: Optional[dict[str, BaseDistribution]] = None,
*,
candidates_func: Optional[Callable[..., Tensor]] = None,
- constraints_func: Optional[Callable[..., Tensor]] = None,
n_startup_trials: int = 10,
- consider_running_trials: bool = False,
independent_sampler: Optional[BaseSampler] = None,
seed: int | None = None,
device: torch.device | str | None = None,
trial_chunks: int = 128,
):
- if constraints_func is not None:
- raise NotImplementedError("constraints_func is currently not supported by BoTorchSampler.")
- if consider_running_trials:
- raise NotImplementedError("consider_running_trials is currently not supported by BoTorchSampler.")
if candidates_func is not None and not callable(candidates_func):
raise TypeError("candidates_func must be callable.")
self._candidates_func = candidates_func
@@ -206,7 +199,6 @@ def __init__(
self._seed = seed
self.trial_chunks = trial_chunks
- self._study_id: int | None = None
self.search_space = {} if search_space is None else dict(search_space)
if isinstance(device, str):
device = torch.device(device)
@@ -643,12 +635,9 @@ def __init__(
search_space: dict[str, BaseDistribution],
*,
seed: int | None = None,
- constant_liar: bool = False,
independent_sampler: BaseSampler | None = None,
) -> None:
super().__init__(search_space, seed)
- if constant_liar:
- raise NotImplementedError("constant_liar is not supported by HEBOSampler.")
self._hebo = HEBO(_convert_to_hebo_design_space(search_space), scramble_seed=self._seed)
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
self._rng = np.random.default_rng(seed)
@@ -716,7 +705,6 @@ def __init__(
upper: np.ndarray,
seed: Optional[int] = None,
population_size: Optional[int] = None,
- learning_rate: Optional[float] = None,
last_n: int = 4096,
loco_step_size: float = 0.1,
device: str | None = None,
@@ -733,7 +721,6 @@ def __init__(
self.last_n = last_n
self.batchnorm_decay = batchnorm_decay
self.score_decay = score_decay
- self._learning_rate = learning_rate or 1.0 / np.sqrt(n_dimension)
self._mean = torch.from_numpy(mean).to(device)
self._sigma = torch.from_numpy(inv_sigma).to(device)
self._lower = torch.from_numpy(lower).to(device)
@@ -832,20 +819,16 @@ def __init__(
search_space: Dict[str, BaseDistribution],
x0: Optional[Dict[str, Any]] = None,
sigma0: Optional[float] = None,
- lr: Optional[float] = None,
n_startup_trials: int = 1,
independent_sampler: Optional[BaseSampler] = None,
- warn_independent_sampling: bool = True,
seed: Optional[int] = None,
population_size: Optional[int] = None,
) -> None:
self.search_space = search_space
self._x0 = x0
self._sigma0 = sigma0
- self._lr = lr
self._independent_sampler = independent_sampler or optuna.samplers.RandomSampler(seed=seed)
self._n_startup_trials = n_startup_trials
- self._warn_independent_sampling = warn_independent_sampling
self._optimizer: Optional[FastINGO] = None
self._seed = seed
self._population_size = population_size
@@ -906,7 +889,6 @@ def sample_relative(
return {}
if len(search_space) == 1:
- self._warn_independent_sampling = False
return {}
trans = SearchSpaceTransform(search_space)
@@ -949,7 +931,6 @@ def _init_optimizer(
upper=upper_bounds,
seed=self._seed,
population_size=population_size,
- learning_rate=self._lr,
)
def sample_independent(
@@ -1022,10 +1003,7 @@ def __init__(
search_space: Optional[dict[str, BaseDistribution]] = None,
*,
seed: int | None = None,
- constraints_func: Optional[Callable[..., Any]] = None,
) -> None:
- if constraints_func is not None:
- raise NotImplementedError("constraints_func is not supported by AutoSampler.")
if samplers is None:
if search_space is None:
raise ValueError("AutoSampler requires a search_space when using the default sampler schedule.")
@@ -1036,7 +1014,6 @@ def __init__(
self._rng = LazyRandomState(seed)
self._random_sampler = RandomSampler(seed=seed)
self._thread_local_sampler = ThreadLocalSampler()
- self._constraints_func = constraints_func
self._completed_trials = 0
self._current_index = -1
diff --git a/heavyball/utils.py b/heavyball/utils.py
index f935fca..0101df7 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -7,11 +7,10 @@
import itertools
import math
import pickle
-import random
import re
import string
import warnings
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -107,8 +106,19 @@ def _fn(*args, **kwargs):
return _fn
+def decorator_no_fullgraph(func: Callable):
+ return decorator_knowngood(func, fullgraph=False)
+
+
einsum_base = string.ascii_lowercase
+no_compile_qr = torch.compiler.disable(torch.linalg.qr)
+no_compile_eigh = torch.compiler.disable(torch.linalg.eigh)
+no_compile_solve = torch.compiler.disable(torch.linalg.solve)
+no_compile_svd = torch.compiler.disable(torch.linalg.svd)
+no_compile_lobpcg = torch.compiler.disable(torch.lobpcg)
+no_compile_solve_triangular = torch.compiler.disable(torch.linalg.solve_triangular)
+
@decorator_knowngood
def compiled_einsum(expr, *args):
@@ -657,7 +667,7 @@ def _scion_bias_rms_direction(x: Tensor, eps: float = 1e-8) -> Tensor:
def _scion_spectral_direction(x: Tensor) -> Tensor:
flat = x.reshape(x.shape[0], -1)
- inplace_orthogonal_(flat)
+ flat = inplace_orthogonal_(flat)
normalized = flat.reshape_as(x)
in_dim = max(flat.shape[1], 1)
scale = math.sqrt(x.shape[0] / in_dim)
@@ -666,7 +676,7 @@ def _scion_spectral_direction(x: Tensor) -> Tensor:
def _scion_spectral_conv_direction(x: Tensor) -> Tensor:
flat = x.reshape(x.shape[0], -1)
- inplace_orthogonal_(flat)
+ flat = inplace_orthogonal_(flat)
normalized = flat.reshape_as(x)
out_channels, in_channels = x.shape[:2]
spatial = math.prod(x.shape[2:]) if x.ndim > 2 else 1
@@ -771,11 +781,11 @@ def _compilable_grafting(magnitude, direction):
return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
-@decorator_knowngood
+@decorator_no_fullgraph
def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
if not isinstance(mode, ZerothPowerMode):
mode = ZerothPowerMode(mode)
- if not isinstance(scale_mode, ZerothPowerMode):
+ if not isinstance(scale_mode, OrthoScaleMode):
scale_mode = OrthoScaleMode(scale_mode)
if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
y = zeropower_via_newtonschulz5(x, 5)
@@ -784,12 +794,12 @@ def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor
elif mode == ZerothPowerMode.legacy_newtonschulz:
y = legacy_zeropower_via_newtonschulz5(x, 5)
elif mode == ZerothPowerMode.qr:
- y = torch.linalg.qr(promote(x)).Q
+ y = no_compile_qr(promote(x)).Q
elif mode == ZerothPowerMode.svd:
- u, _s, vt = torch.linalg.svd(promote(x))
+ u, _s, vt = no_compile_svd(promote(x))
y = u @ vt
elif mode == ZerothPowerMode.legacy_svd:
- u, _s, vt = torch.linalg.svd(promote(x))
+ u, _s, vt = no_compile_svd(promote(x))
y = u @ vt.T
else:
raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
@@ -816,7 +826,7 @@ def _compilable_scatter_set(target, source, index):
target[:] = source.contiguous()[index].reshape_as(target)
-# @decorator_knowngood
+@decorator_no_fullgraph
def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor):
"""
Computes the eigenbases of the preconditioner using one round of power iteration
@@ -885,6 +895,126 @@ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], *exp_avg: Tensor
copy_stochastic_(q, q_new)
+def _transform_projected_state(old_qs: List[Optional[Tensor]], new_qs: List[Optional[Tensor]], *states: Tensor):
+ if not states:
+ return
+
+ ref = states[0]
+ if ref is None or ref.dim() == 0:
+ return
+
+ assert ref.ndim < 13, "ref.ndim must be less than 13"
+ in_str = einsum_base[: ref.dim()]
+ out_str = einsum_base[ref.dim() : 2 * ref.dim()]
+
+ old_basis = ",".join([o + i for q, i, o in zip(old_qs, in_str, in_str.upper()) if q is not None])
+ if not old_basis:
+ return
+
+ new_basis = ",".join([i + o for q, i, o in zip(new_qs, in_str.upper(), out_str) if q is not None])
+ out_str = "".join([o if o in new_basis else i for i, o in zip(in_str, out_str)])
+ subscripts = f"{in_str},{old_basis},{new_basis}->{out_str}"
+ old_basis = [promote(q) for q in old_qs if q is not None]
+ new_basis = [promote(q) for q in new_qs if q is not None]
+
+ for state in states:
+ new = compiled_einsum(subscripts, promote(state), *old_basis, *new_basis)
+ copy_stochastic_(state, new)
+
+
+@decorator_no_fullgraph
+def init_psgd_eigenbasis(Q: List[Tensor]):
+ out = []
+
+ for q in Q:
+ if q.ndim < 2:
+ out.append(None)
+ continue
+
+ q32 = promote(q)
+ out.append(_stable_symmetric_basis(q32.mT @ q32, out_device=q.device, out_dtype=q.dtype))
+
+ return out
+
+
+@decorator_no_fullgraph
+def get_psgd_eigenbasis(Q: List[Tensor], prev: List[Optional[Tensor]]):
+ out = []
+
+ for q, old_basis in zip(Q, prev):
+ if q.ndim < 2:
+ out.append(None)
+ continue
+ if old_basis is None:
+ raise ValueError(
+ "get_psgd_eigenbasis requires a previous basis for matrix blocks; use init_psgd_eigenbasis"
+ )
+
+ q32 = promote(q)
+ old_basis32 = promote(old_basis)
+ Y = q32.mT @ (q32 @ old_basis32)
+ basis_raw = no_compile_qr(Y, mode="reduced").Q.to(dtype=q.dtype)
+ projected = q32 @ promote(basis_raw)
+ sort_idx = torch.argsort(compiled_einsum("ij,ij->j", projected, projected), descending=True)
+ basis_raw = basis_raw.index_select(1, sort_idx)
+ signs = compiled_einsum("ij,ij->j", old_basis32, promote(basis_raw))
+ signs = torch.where(signs < 0, -torch.ones_like(signs), torch.ones_like(signs)).to(dtype=basis_raw.dtype)
+ basis = basis_raw * signs.view(1, -1)
+ out.append(basis)
+
+ return out
+
+
+@decorator_no_fullgraph
+def update_psgd_eigenbasis(Q: List[Tensor], Q_basis: List[Tensor], *states: Tensor):
+ new_basis = get_psgd_eigenbasis(Q, Q_basis)
+ _transform_projected_state(Q_basis, new_basis, *states)
+
+ for i, (old_basis, new_basis_i) in enumerate(zip(Q_basis, new_basis)):
+ if old_basis is None: # happens only if ndim < 2
+ continue
+ copy_stochastic_(old_basis, new_basis_i)
+
+
+def _stable_symmetric_basis(
+ m: Tensor,
+ max_eps: float = 1e-3,
+ min_eps: float = 1e-30,
+ *,
+ out_device=None,
+ out_dtype=None,
+):
+ out_device = m.device if out_device is None else out_device
+ out_dtype = m.dtype if out_dtype is None else out_dtype
+ m = promote(m.data)
+
+ eps = min_eps
+ while True:
+ try:
+ eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
+ _eigval, eigvec = no_compile_eigh(m + eps * eye)
+ return torch.flip(eigvec, [1]).to(device=out_device, dtype=out_dtype)
+ except torch.OutOfMemoryError:
+ if m.device.type == "cpu":
+ raise
+ if torch.cuda.is_available():
+ torch.cuda.synchronize(m.device)
+ clean()
+ m = m.cpu()
+ except RuntimeError as e:
+ if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
+ torch.cuda.synchronize(m.device)
+ clean()
+ m = m.cpu()
+ elif m.dtype != torch.double:
+ m = m.double()
+ elif eps < max_eps:
+ eps = eps ** (2 / 3)
+ else:
+ raise
+ clean()
+
+
def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
"""
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
@@ -896,38 +1026,7 @@ def get_orthogonal_matrix(mat, max_eps: float = 1e-3, min_eps: float = 1e-30):
final.append(None)
continue
- device, dtype = m.device, m.dtype
- m = promote(m.data)
-
- eps = min_eps
- while True:
- try:
- eye = torch.eye(m.shape[0], device=m.device, dtype=m.dtype)
- _eigval, eigvec = torch.linalg.eigh(m + eps * eye)
- eigvec = eigvec.to(device=device, dtype=dtype)
- break
- except torch.OutOfMemoryError:
- if m.device.type == "cpu":
- raise
- if torch.cuda.is_available():
- torch.cuda.synchronize(m.device)
- clean()
- m = m.cpu()
- except RuntimeError as e:
- if torch.cuda.is_available() and ("CUDA" in str(e) or "illegal memory access" in str(e)):
- torch.cuda.synchronize(m.device)
- clean()
- m = m.cpu()
- elif m.dtype != torch.double:
- m = m.double()
- elif eps < max_eps:
- eps = eps ** (2 / 3)
- else:
- raise
- clean()
-
- eigvec = torch.flip(eigvec, [1])
- final.append(eigvec)
+ final.append(_stable_symmetric_basis(m, max_eps=max_eps, min_eps=min_eps))
return final
@@ -988,6 +1087,8 @@ def scalar_guard(*args):
out.append(torch.empty((), dtype=promote(ref.dtype), device=ref.device).fill_(x))
elif isinstance(x, int):
out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+ elif isinstance(x, Tensor) and x.is_floating_point() and x.ndim == 0:
+ out.append(x.to(dtype=promote(ref.dtype)))
else:
out.append(x)
if len(xs) == 1:
@@ -1268,6 +1369,28 @@ def hasattr_none(obj, name):
return getattr(obj, name, None) is not None
+def set_temporary(group: dict, tensor: Tensor, **kwargs):
+ if not kwargs:
+ return
+ state = group.setdefault("_tmp", {}).setdefault(id(tensor), {"tensor": tensor})
+ state.update(kwargs)
+
+
+def get_temporary(group: dict, tensor: Tensor):
+ tmp = group.get("_tmp")
+ return None if tmp is None else tmp.get(id(tensor))
+
+
+def take_temporary(group: dict, tensor: Tensor, *keys):
+ state = get_temporary(group, tensor)
+ if state is None:
+ return None if len(keys) == 1 else (None,) * len(keys)
+ out = tuple(state.pop(key, None) for key in keys)
+ if len(state) == 1:
+ group["_tmp"].pop(id(tensor), None)
+ return out[0] if len(keys) == 1 else out
+
+
class ExactHVPFailed(ValueError):
pass
@@ -1291,11 +1414,11 @@ class StatefulOptimizer(torch.optim.Optimizer):
compile_step: bool = False
hessian_approx: bool = False
precond_schedule: Union[Callable, float, None] = None
- stochastic_schedule: bool | Literal[use_default] = use_default
finite_differences: bool = False
fallback_to_finite_differences: bool = True
_fallback_enabled: bool = False
hvp_interval: int = 1 # grad is faster initially, hvp later
+ consume_grad: bool = True
_INSTANCE_ATTRS = (
"compile_step",
@@ -1303,30 +1426,22 @@ class StatefulOptimizer(torch.optim.Optimizer):
"fallback_to_finite_differences",
"hvp_interval",
"hessian_approx",
+ "consume_grad",
)
- def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
+ def __init__(self, params, defaults, use_ema: bool = False):
for attr in self._INSTANCE_ATTRS:
if attr in defaults:
val = defaults.pop(attr)
if val is not use_default:
setattr(self, attr, val)
- super().__init__(params, {**defaults, "foreach": foreach})
+ defaults.setdefault("multi_tensor", True)
+ super().__init__(params, defaults)
self.use_ema = use_ema
self.mapping = {}
self.mapping_inverse = {}
- if self.stochastic_schedule is use_default:
- stochastic_schedule = None
- for group in self.param_groups:
- new = group.get("stochastic_schedule", stochastic_schedule)
- if stochastic_schedule is not None and new != stochastic_schedule:
- raise ValueError("All parameter groups must have the same stochastic_schedule.")
- stochastic_schedule = new
- self.stochastic_schedule = stochastic_schedule
-
- self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
- self.precond_rng = random.Random(0x12312)
+ self.inner_group = {}
self._is_preconditioning = None
if self.hessian_approx and self.compile_step:
@@ -1338,22 +1453,24 @@ def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False
def _store_stats(self, state_dict: dict[str, any]):
state_dict["heavyball"] = {
"inner_group": self.inner_group,
- "precond_rng": pickle.dumps(self.precond_rng),
"use_ema": self.use_ema,
"ema_decay": self.ema_decay,
"compile_step": self.compile_step,
"hessian_approx": self.hessian_approx,
"precond_schedule": pickle.dumps(self.precond_schedule),
- "stochastic_schedule": self.stochastic_schedule,
"fallback_to_finite_differences": self.fallback_to_finite_differences,
"_fallback_enabled": self._fallback_enabled,
"hvp_interval": self.hvp_interval,
}
+ _REMOVED_STATS = frozenset({"stochastic_schedule", "precond_rng"})
+
def _load_stats(self, state_dict):
sd = state_dict.pop("heavyball", {})
for k, v in sd.items():
- if k in ("precond_rng", "precond_schedule"):
+ if k in self._REMOVED_STATS:
+ continue
+ if k in ("precond_schedule",):
v = pickle.loads(v)
setattr(self, k, v)
@@ -1392,21 +1509,31 @@ def split_p_and_g_in_group(
should_promote: bool = True,
raw: bool = False,
):
+ tmp = group.get("_tmp")
for p in group["params"]:
- grad = getattr(p, "grad", None)
- if grad is None and skip_none:
- continue
-
- p.grad = None
-
if raw:
+ grad = getattr(p, "grad", None)
+ if grad is None and skip_none:
+ continue
+ if self.consume_grad:
+ p.grad = None
yield p, grad
continue
+ state = None if tmp is None else tmp.get(id(p))
+ grad = None if state is None else state.pop("grad", None)
+ if grad is None:
+ grad = getattr(p, "grad", None)
+ if grad is None and skip_none:
+ continue
+
+ if self.consume_grad:
+ p.grad = None
+
if group.get("merge_dims", False) and not p.data.is_contiguous():
for fmt in (torch.channels_last, torch.channels_last_3d):
if p.data.is_contiguous(memory_format=fmt):
- p._restore_memory_format = fmt
+ set_temporary(group, p, restore_memory_format=fmt)
break
p.data = p.data.contiguous()
@@ -1414,20 +1541,23 @@ def split_p_and_g_in_group(
for i, pv in enumerate(p_views):
self.mapping_inverse[_tensor_key(pv)] = (p, i)
- vector = getattr(p, "vector", None)
- hessian_vector = getattr(p, "hessian_vector", None)
- p.vector = None
- p.hessian_vector = None
-
- grad, vs, hvs = [
- [None] * len(p_views) if x is None else merge_group(group, x) #
- for x in (grad, vector, hessian_vector)
- ]
+ if state is None:
+ vector = hessian_vector = None
+ else:
+ vector = state.pop("vector", None)
+ hessian_vector = state.pop("hessian_vector", None)
+ if len(state) == 1:
+ tmp.pop(id(p), None)
+ grad = itertools.repeat(None, len(p_views)) if grad is None else merge_group(group, grad)
+ vs = itertools.repeat(None, len(p_views)) if vector is None else merge_group(group, vector)
+ hvs = itertools.repeat(None, len(p_views)) if hessian_vector is None else merge_group(group, hessian_vector)
for pv, g, v, hv in zip(p_views, grad, vs, hvs):
g = promote_detach(g, should_promote)
- pv.vector = promote_detach(v, should_promote)
- pv.hessian_vector = promote_detach(hv, should_promote)
+ v = promote_detach(v, should_promote)
+ hv = promote_detach(hv, should_promote)
+ if v is not None or hv is not None:
+ set_temporary(group, pv, vector=v, hessian_vector=hv)
yield pv, g
def state_size(self) -> int:
@@ -1501,10 +1631,10 @@ def _finite_differences_hvp(self, closure):
for group in self.param_groups:
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
grads.append(g)
- p.vector = torch.randn_like(p)
- p.orig = p.data.clone()
+ vector = torch.randn_like(p)
+ set_temporary(group, p, vector=vector, orig=p.data.clone())
# scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
- stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
+ stochastic_add_(p.data, vector, torch.finfo(torch.float32).eps ** 0.5)
with torch.enable_grad():
closure()
@@ -1513,22 +1643,22 @@ def _finite_differences_hvp(self, closure):
# this costs more memory, but the imprecision seems too severe to use the other method
for group in self.param_groups:
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
- p.grad = grads.pop(0)
- stochastic_add_divide_(g, p.grad, -1, torch.finfo(p.dtype).eps ** 0.5)
- p.hessian_vector = g
- p.data.copy_(p.orig)
- del p.orig
+ grad = grads.pop(0)
+ set_temporary(group, p, grad=grad, hessian_vector=g)
+ stochastic_add_divide_(g, grad, -1, torch.finfo(torch.float32).eps ** 0.5)
+ p.data.copy_(take_temporary(group, p, "orig"))
return loss
def _double_backward_hvp(self, closure):
with torch.enable_grad(), patch_backward():
loss = closure()
- params, grads = [], []
+ params, grads, groups = [], [], []
for group in self.param_groups:
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
params.append(p)
grads.append(g)
+ groups.append(group)
if not params:
raise ValueError("No parameter has gradients")
@@ -1541,10 +1671,8 @@ def _double_backward_hvp(self, closure):
raise ExactHVPFailed(str(e.args))
unused = []
- for p, g, v, hv in zip(params, grads, vs, hvs):
- p.hessian_vector = detach(hv)
- p.grad = detach(g)
- p.vector = detach(v)
+ for group, p, g, v, hv in zip(groups, params, grads, vs, hvs):
+ set_temporary(group, p, grad=detach(g), vector=detach(v), hessian_vector=detach(hv))
if hv is None:
unused.append(list(p.shape))
@@ -1597,33 +1725,39 @@ def _handle_closure(self, closure):
self._fallback_enabled = True
return self._handle_closure(closure)
+ def _cleanup_temporary_tensors(self):
+ for group in self.param_groups:
+ tmp = group.pop("_tmp", None)
+ if tmp is None:
+ continue
+ for state in tmp.values():
+ fmt = state.get("restore_memory_format")
+ if fmt is None:
+ continue
+ tensor = state["tensor"]
+ self.mapping_inverse.pop(_tensor_key(tensor), None)
+ tensor.data = tensor.data.to(memory_format=fmt)
+ self.mapping_inverse[_tensor_key(tensor)] = (tensor, 0)
+
def step(self, closure: Optional[Callable] = None):
if self.precond_schedule is None:
self._is_preconditioning = False
else:
- self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
+ self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule)
loss = self._handle_closure(closure)
if self.use_ema:
self.ema_update()
# we assume that parameters are constant and that there are no excessive recompiles
with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
- for group in self.param_groups:
- if "param_count" not in group:
- group["param_count"] = sum(p.numel() for p in group["params"])
- group["is_preconditioning"] = self._is_preconditioning
- self._step(group)
- for real, views in self.mapping.items():
- fmt = getattr(real, "_restore_memory_format", None)
- if fmt is not None:
- del self.mapping_inverse[_tensor_key(real)]
- real.data = real.data.to(memory_format=fmt)
- self.mapping_inverse[_tensor_key(real)] = (real, 0)
- del real._restore_memory_format
- for tensor in (real, *views):
- for key in ("grad", "vector", "hessian_vector", "orig"):
- if hasattr(tensor, key):
- setattr(tensor, key, None)
+ try:
+ for group in self.param_groups:
+ if "param_count" not in group:
+ group["param_count"] = sum(p.numel() for p in group["params"])
+ group["is_preconditioning"] = self._is_preconditioning
+ self._step(group)
+ finally:
+ self._cleanup_temporary_tensors()
return loss
@@ -2114,17 +2248,17 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
def stochastic_round_(ref: Tensor, source: Tensor | None = None):
if source is None:
source = ref
- if ref.dtype != torch.bfloat16:
- return source.to(ref.dtype)
- if source.dtype == torch.bfloat16:
- return source
- if source.dtype in (torch.float16, torch.float32, torch.float64):
- source = source.to(torch.float32).view(dtype=torch.int32)
- noise = sum(torch.randint_like(source, low=0, high=(1 << 16)) for _ in range(dither_steps))
- noise = noise + source - (dither_steps - 1) * (1 << 15) # center | x - (N-1)*delta/2
- noise = noise.bitwise_and(-65536) # FFFF0000 mask, preserves sign+exp+7 mantissa bits
- return noise.view(dtype=torch.float32).bfloat16()
- return source.to(ref.dtype)
+ dtype = torch.bfloat16
+ else:
+ dtype = ref.dtype
+ if dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
+ return source.to(dtype)
+
+ source = source.to(torch.float32).view(dtype=torch.int32)
+ noise = sum(torch.randint_like(source, low=0, high=(1 << 16)) for _ in range(dither_steps))
+ noise = noise + source - (dither_steps - 1) * (1 << 15) # center | x - (N-1)*delta/2
+ noise = noise.bitwise_and(-65536) # FFFF0000 mask, preserves sign+exp+7 mantissa bits
+ return noise.view(dtype=torch.float32).bfloat16()
@decorator_knowngood
@@ -2151,7 +2285,7 @@ def _compilable_update_(
caution: bool,
g: List[Optional[Tensor]],
):
- for i, (u_, g_, p_) in enumerate(zip(u, g, p)): # lr is data-dependent -> can't compile a foreach
+ for i, (u_, g_, p_) in enumerate(zip(u, g, p)): # lr is data-dependent -> can't compile a multi-tensor op
u_ = promote(u_.view_as(p_))
p32_ = promote(p_)
if caution:
@@ -2175,9 +2309,10 @@ def update_param_(
_compilable_update_(param, update, decay, lr, caution, grad)
-def precond_schedule(step, precond_scheduler):
- precond_prob = max(step, 1) ** precond_scheduler[0]
- precond_prob = math.log10(precond_prob)
+@decorator_knowngood
+def precond_schedule(step: Tensor, precond_scheduler):
+ precond_prob = step.clamp(min=1) ** precond_scheduler[0]
+ precond_prob = torch.log10(precond_prob)
precond_prob = precond_prob ** precond_scheduler[1] + 1
return 1 / precond_prob
@@ -2363,7 +2498,7 @@ def init_Q_exprs(
@decorator_knowngood
def psgd_balance_Q(Q):
- norms = [promote(q.norm(float("inf"))).log() for q in Q]
+ norms = [promote(q.abs().max()).log() for q in Q]
geometric_mean = sum([n for n in norms]) / len(Q)
for q, n in zip(Q, norms):
q *= (geometric_mean - n).exp()
@@ -2524,10 +2659,10 @@ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
@decorator_knowngood
-def dampen_grad(g: Tensor, damp: float = 2**-13):
- # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
+def dampen_grad(g: Tensor, damp: float = 1e-9):
v = torch.randn_like(g)
- return v, g + damp * g.abs().mean() * v
+ damping = damp + torch.finfo(torch.float32).eps * g.abs()
+ return v, g + damping * v
@decorator_knowngood
@@ -2620,7 +2755,7 @@ def multi_flatten(*xs: Tuple[List[Tensor], int]):
@decorator_knowngood
-def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
+def dampen_multiple(g: List[Tensor], damp: float = 1e-9):
vs = []
gs = []
for g_ in g:
@@ -2631,8 +2766,7 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
def casted_einsum(expr: str, *args: Tensor) -> Tensor:
- md = min_dtype(args)
- return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
+ return compiled_einsum(expr, *[promote(a) for a in args]).to(args[-1].dtype)
@decorator_knowngood
@@ -2671,14 +2805,13 @@ def psgd_calc_A_and_conjB(G: Tensor, Q, conjB: Tensor | None): # conjB ("V", "v
conjB = torch.randn_like(G)
exprA = cached_precond_grad_expr(ndim_tuple(Q), G.ndim) # calcA expr and cached precond expr are the same
A = casted_einsum(exprA, *Q, G)
- solve = torch.compiler.disable(torch.linalg.solve_triangular)
transposed_shape = original_shape = conjB.shape
prev_i = -1
qs, conjB = _psgd_calc_scalars_(Q, conjB)
for i, tri_q in qs:
conjB, transposed_shape = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, i)
prev_i = i
- conjB = solve(tri_q, conjB, upper=True, left=False)
+ conjB = no_compile_solve_triangular(tri_q, conjB, upper=True, left=False)
conjB, _ = _reshape_conjB(conjB, transposed_shape, original_shape, prev_i, -1)
return A, conjB
@@ -2687,7 +2820,7 @@ def max_singular_value_exact(A, use_lobpcg: bool = False):
try:
if use_lobpcg:
A = A @ A.T
- eigval, _ = torch.compiler.disable(torch.lobpcg)(A, k=1, largest=True)
+ eigval, _ = no_compile_lobpcg(A, k=1, largest=True)
return eigval[0].sqrt()
else:
return torch.linalg.svd(promote(A), driver="gesvdj")[1].max().to(A.dtype) # == linalg.matrix_norm(A, ord=2)
@@ -2727,7 +2860,7 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
Adapted from @evanatyourservice
"""
if max_abs is None:
- max_abs = A.norm(float("inf")).clamp(min=1e-8)
+ max_abs = A.abs().max().clamp(min=1e-8)
# cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
k = 2 ** math.ceil(math.log2(math.log2(min(A.shape)))) # next-largest-power-of-2 of log2-of-size
@@ -2768,6 +2901,45 @@ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False,
return max_singular_value_power_iter(A, None, iterations=power_iter)
+@decorator_knowngood
+def max_eigenvalue_spd(A_outer: Tensor, power_iter: int = 4) -> Tensor:
+ """Power iteration for the largest eigenvalue of a symmetric positive (semi)definite matrix.
+ Exploits A = A^T: A^T A = A^2, so v -> A^T(Av) = v -> A(Av), saving a transpose.
+ Uses x @ A.mT (gemm transB=true) for faster BLAS dispatch than A.mv(x)."""
+ if A_outer.ndim < 2:
+ return A_outer.max()
+ x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
+ x_norm = promote(x_norm)
+
+ def _inner():
+ x = A_outer.index_select(0, max_idx).flatten().contiguous()
+ A = promote(A_outer) / x_norm
+ x = x / x_norm
+
+ def _mv(x):
+ return promote((x @ A.mT) @ A.mT)
+
+ for _ in range(power_iter):
+ x = F.normalize(_mv(x), dim=0)
+ return (x @ _mv(x)).sqrt() * x_norm
+
+ return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone()).squeeze()
+
+
+@decorator_knowngood
+def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None:
+ Q_ = promote(Q)
+ R = (Q_.T - Q_).contiguous()
+ R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal
+ R = R / R_norm
+ RQ = R @ Q_
+ RRQ = R @ RQ
+ tr_RQ = RQ.diagonal().sum()
+ tr_RRQ = RRQ.diagonal().sum()
+ a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size)
+ copy_stochastic_(Q, Q_ + a * (RQ + 0.5 * a * RRQ))
+
+
@decorator_knowngood
def clamped_max_singular_value(
A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
@@ -2821,10 +2993,10 @@ def _approx():
@decorator_knowngood
-def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
+def _balance_to_triu(Q: "TriuOrLine"):
if isinstance(Q[0], tuple):
psgd_balance_Q([o[1] for o in Q])
- return line_to_triu(Q, symmetric_output)
+ return line_to_triu(Q)
psgd_balance_Q(Q)
return Q
@@ -2927,33 +3099,20 @@ def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
return coeff0.float(), coeffs.float()
-@decorator_knowngood
-def _psgd_default_preconditioner_grad(
- terms: List[Tuple[Tensor, Tensor]],
- Q: List[Tensor],
-) -> List[Tensor]:
- out = []
- for q, (x, y) in zip(Q, terms):
- x = promote(x)
- y = promote(y)
- update = x - y
- if q.ndim < 2:
- update = promote(q) * update
- else:
- update = (promote(q) @ update).triu()
- out.append(update)
- return out
+def _update_lb(ell: Tensor, lb_state: Tensor, beta: Tensor) -> Tensor:
+ ell = promote(ell)
+ ell = ell.maximum(promote(lb_state) + (ell - promote(lb_state)) * (1 - beta))
+ copy_stochastic_(lb_state, ell)
+ return ell
-@decorator
+@decorator_no_fullgraph
def psgd_update_precond(
G: Tensor,
precond_lr: float,
oq: "TriuOrLine",
store_triu_as_line: bool,
- velocity: Optional[List[Tensor]],
beta2: float,
- ortho_method: Optional[str],
V: Tensor,
running_lower_bound: List[Tensor],
lower_bount_beta: float,
@@ -2965,15 +3124,61 @@ def psgd_update_precond(
precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
A, conjB = psgd_calc_A_and_conjB(G, Q, V)
- terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
- del A, conjB, V
- updates = _psgd_default_preconditioner_grad(terms, Q)
- _psgd_precond_update_(
- updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
- )
+ del V
+
+ for oq_i, q, exprG, lb_state in zip(oq, Q, exprGs, running_lower_bound):
+ term1 = promote(compiled_einsum(exprG, A, A))
+ term2 = promote(compiled_einsum(exprG, conjB, conjB))
+
+ if q.ndim < 2:
+ ell = _update_lb((term1 + term2).max(), lb_state, lower_bount_beta)
+ update = promote(q) * (term1 - term2)
+ else:
+ ell = _update_lb(max_eigenvalue_spd(term1 + term2, power_iter=power_iter), lb_state, lower_bount_beta)
+ update = (term1 - term2).triu() @ promote(q)
+ if store_triu_as_line:
+ update = triu_to_line([update])[0][1]
+
+ real_oq = oq_i[1] if isinstance(oq_i, tuple) else oq_i
+ copy_stochastic_(real_oq, promote(real_oq) - update / ell * precond_lr)
return None
+@decorator_knowngood
+def psgd_pro_update_precond(
+ G: Tensor,
+ precond_lr: float,
+ Q: List[Tensor],
+ running_lower_bound: List[Tensor],
+ lower_bount_beta: float,
+ power_iter: int,
+ dampening: float,
+) -> None:
+    """Update Kronecker product preconditioner Q with the Q^0.5 E Q^1.5 (PRO) method."""
+ psgd_balance_Q(Q)
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
+ precond_lr, lower_bount_beta = scalar_guard(precond_lr, lower_bount_beta, G)
+
+ damping = dampening + torch.finfo(torch.float32).eps * G.abs()
+ Pg = psgd_precond_grad(G + damping * torch.randn_like(G), Q)
+
+ total_numel = G.numel()
+ for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound):
+ term1 = compiled_einsum(exprG, Pg, Pg)
+ q_ = promote(q)
+
+ if q.ndim < 2:
+ term2 = total_numel / max(1, q.numel())
+ ell = _update_lb(term1.max() + term2, lb_state, lower_bount_beta)
+ copy_stochastic_(q, q_ - q_ * (term1 - term2) / ell * precond_lr)
+ else:
+ term2 = total_numel / q.shape[0]
+ ell = _update_lb(max_eigenvalue_spd(term1, power_iter=power_iter) + term2, lb_state, lower_bount_beta)
+ copy_stochastic_(q, q_ - (term1 @ q_ - term2 * q_) / ell * precond_lr)
+ procrustes_step(q)
+ del Pg
+
+
@decorator_knowngood
def bf16_matmul(x: Tensor, y: Tensor):
return (promote(x) @ promote(y)).to(x.dtype)
@@ -3037,8 +3242,8 @@ def eigvecs_product_rank1(
using the Householder reflector with first column v. Never materializes V.
Args:
- G: shape (..., d) — gradient row(s) you want to rotate into eigenbasis.
- v: shape (d,) — current unit direction (top eigenvector of P).
+ G: shape (..., d) - gradient row(s) you want to rotate into eigenbasis.
+ v: shape (d,) - current unit direction (top eigenvector of P).
w: optional Householder vector w; pass to reuse across calls.
Returns:
@@ -3087,7 +3292,7 @@ def _psgd_precond_update_(
q = promote(oq)
if update.ndim < 2:
- lb = update.norm(float("inf"))
+ lb = update.abs().max()
else:
lb = max_singular_value(update, power_iter=power_iter)
update = promote(update)
@@ -3101,65 +3306,7 @@ def _psgd_precond_update_(
@decorator_knowngood
-def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
- """
- I: Identity
- U: Update / gg / target
- Q: q, preconditioner
- scale: scalar scale
- ---
- U = T * scale - I
- F = I - U # = 2I - U * scale
- O = F @ Q @ F - Q
- """
- out = []
- for gg, q in zip(GG, Q):
- if gg.ndim < 2:
- scale = max(1, gg.numel()) / numel
- target = promote(gg)
- update = target * scale - 1
- out.append(q - (1 - update) * q * (1 - update))
- else:
- scale = gg.size(0) / numel
- gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
- update = q - casted_einsum("ab,cd,bc", gg, gg, q)
- out.append(update + update.T) # make matrix symmetric
- return out
-
-
@decorator
-def inverse_free_psgd_update_precond(
- G: Tensor,
- precond_lr: float,
- oq: List[Tensor],
- store_triu_as_line: bool,
- velocity: Optional[List[Tensor]],
- beta2: float,
- ortho_method: Optional[str],
- V: None,
- running_lower_bound: List[Tensor],
- lower_bount_beta: float,
- power_iter: int,
-) -> Tensor:
- """Update Kronecker product preconditioner Q with pair (V, G)."""
- assert V is None
- assert ortho_method is None
- assert velocity is None
- del V, ortho_method, velocity
-
- Q = _balance_to_triu(oq, True)
- precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
- exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
-
- G = psgd_precond_grad(G, Q)
- terms = [compiled_einsum(exprG, G, G) for exprG in exprGs]
- matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel())
- _psgd_precond_update_(
- matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
- )
- return G
-
-
@decorator_knowngood
def _clip(x, norm, clip_at, eps=1e-8):
x32 = promote(x)
@@ -3420,15 +3567,13 @@ def triu_to_line(Q_list: List[Tensor]):
@decorator_knowngood
-def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
+def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
new = []
for shape, q in Q_list:
if shape is not None:
x, y = torch.triu_indices(*shape, device=q.device)
q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype)
q_mat[x, y] = q
- if symmetric_output:
- q_mat[y, x] = q
q = q_mat
new.append(q)
return new
@@ -3443,14 +3588,10 @@ def warn_once(msg):
_warned.add(msg)
-def psgd_should_update(
- group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
-):
+def psgd_should_update(group, prob: Union[float, callable], name: str = "cumulative_prob"):
group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
if not isinstance(prob, float):
prob = prob(group[f"{name}_prob_step"])
- if group["stochastic_schedule"]:
- return rng.random() < prob
cumulative_prob = group.get(name, 0)
group[name] = cumulative_prob + prob
return int(group[name]) > int(cumulative_prob)
@@ -3471,18 +3612,13 @@ def precond_grad_cached_(
cached_q: List[Tensor],
caution: bool = False,
grad: Optional[Tensor] = None,
- cast: bool = True,
):
if caution:
ea = _compilable_cautioning(grad, ea)
- md = min_dtype(list(cached_q) + [ea])
- args = [q.to(md) for q in cached_q]
- args = args + [ea.to(md)]
+ args = [promote(q) for q in cached_q]
+ args = args + [promote(ea)]
expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
- new = compiled_einsum(expr, *args)
- if cast:
- return new.to(ea.dtype)
- return new
+ return compiled_einsum(expr, *args)
TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
@@ -3490,7 +3626,7 @@ def precond_grad_cached_(
@decorator_knowngood
def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
- precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
+ precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad)
update_param_(param, precond, lr, decay, caution=False)
@@ -3517,17 +3653,14 @@ def psgd_precond_grad(
caution: bool = False,
grad: Optional[Tensor] = None,
store_triu_as_line: bool = False,
- symmetric_output: bool = False,
):
if caution:
ea = _compilable_cautioning(grad, ea)
if store_triu_as_line:
- preconds = line_to_triu(preconds, symmetric_output)
- md = min_dtype(list(preconds) + [ea])
- args = [q.to(md) for q in preconds]
+ preconds = line_to_triu(preconds)
+ args = [promote(q) for q in preconds]
expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
- new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md))
- return new.to(ea.dtype)
+ return compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
@decorator_knowngood
@@ -3540,7 +3673,6 @@ def _compilable_fused_psgd_precond_grad(
caution,
preconds: TriuOrLine,
store_triu_as_line: bool = False,
- symmetric_output: bool = False,
):
precond = psgd_precond_grad(
ea,
@@ -3548,7 +3680,6 @@ def _compilable_fused_psgd_precond_grad(
caution=caution,
grad=grad,
store_triu_as_line=store_triu_as_line,
- symmetric_output=symmetric_output,
)
update_param_(param, precond, lr, decay, caution=False, grad=grad)
@@ -3562,12 +3693,9 @@ def fused_psgd_precond_grad(
caution,
preconds: TriuOrLine,
store_triu_as_line: bool = False,
- symmetric_output: bool = False,
):
lr, decay = scalar_guard(lr, decay, param[0])
- _compilable_fused_psgd_precond_grad(
- ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
- )
+ _compilable_fused_psgd_precond_grad(ea, param, lr, grad, decay, caution, preconds, store_triu_as_line)
@decorator_knowngood
@@ -3619,6 +3747,41 @@ def caution(g, update):
return _compilable_cautioning(g, update)
+@decorator_knowngood
+def _compilable_hyperball_(
+ p: List[Tensor],
+ u: List[Tensor],
+ init_norm: List[Tensor],
+ lr: Tensor,
+ decay: float,
+ caution: bool,
+ g: List[Tensor],
+):
+ for op, u_, n_, g_ in zip(p, u, init_norm, g):
+ u_ = promote(u_.view_as(op))
+ p_ = promote(op)
+ if decay != 0:
+ u_ = u_ + p_ * decay
+ if caution:
+ u_ = _compilable_cautioning(promote(g_), u_)
+ u_norm = u_.norm()
+ u_norm = u_norm.clamp(min=1e-12)
+ u_ = u_ / u_norm
+ p_ = p_ - lr * u_ * n_
+ p_norm = p_.norm()
+ p_norm = p_norm.clamp(min=1e-12)
+ p_ = p_ * (n_ / p_norm)
+ copy_stochastic_(op, p_)
+
+
+def hyperball_step_(param, update, init_norm, lr, decay, caution, grad):
+ param, update, init_norm, grad = list_guard(param, update, init_norm, grad)
+ lr = scalar_guard(lr, param[0])
+ if not caution:
+ grad = [None] * len(param)
+ _compilable_hyperball_(param, update, init_norm, lr, decay, caution, grad)
+
+
def _inner_precond_update_prob_schedule(
n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
):
diff --git a/interactive/playground.py b/interactive/playground.py
deleted file mode 100644
index bc80f62..0000000
--- a/interactive/playground.py
+++ /dev/null
@@ -1,1596 +0,0 @@
-import functools
-from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple
-
-import gradio as gr
-import numpy as np
-import plotly.graph_objects as go
-import torch
-import torch.nn as nn
-from plotly.subplots import make_subplots
-from sklearn.decomposition import PCA
-
-import heavyball
-import heavyball.chainable as C
-
-# TensorFlow Playground inspired colors
-COLORS = {
- "bg": "#ffffff",
- "surface": "#f7f7f7",
- "primary": "#ff6f00", # Orange
- "secondary": "#0d47a1", # Blue
- "positive": "#4caf50", # Green
- "negative": "#f44336", # Red
- "text": "#212121",
- "text_light": "#757575",
- "border": "#e0e0e0", # Component categories
- "gradient": "#9c27b0", # Purple - Gradient input
- "momentum": "#2196f3", # Blue - Momentum transforms
- "scaling": "#ff5722", # Deep Orange - Scaling transforms
- "regularization": "#009688", # Teal - Regularization
- "normalization": "#795548", # Brown - Normalization
- "update": "#4caf50", # Green - Update rules
-}
-
-
-# Problem base class
-class Problem(ABC):
- """Base class for optimization problems"""
-
- @property
- @abstractmethod
- def dim(self) -> int:
- """Dimension of the parameter space"""
- pass
-
- @abstractmethod
- def init(self) -> np.ndarray:
- """Initial parameters"""
- pass
-
- @abstractmethod
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- """Compute loss given parameters"""
- pass
-
- def bounds(self) -> Optional[List[Tuple[float, float]]]:
- """Parameter bounds for visualization (None if unbounded)"""
- return None
-
- def optimum(self) -> Optional[np.ndarray]:
- """Known optimal point if available"""
- return None
-
-
-class Function2D(Problem):
- """Wrapper for 2D test functions"""
-
- def __init__(self, func, bounds, init, optimal=None):
- self.func = func
- self._bounds = bounds
- self._init = init
- self._optimal = optimal
-
- @property
- def dim(self) -> int:
- return 2
-
- def init(self) -> np.ndarray:
- return np.array(self._init)
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- if x.dim() == 1:
- return self.func(x)
- else:
- # Handle batch of parameters
- return self.func(x[0])
-
- def bounds(self):
- return self._bounds
-
- def optimum(self):
- return np.array(self._optimal) if self._optimal else None
-
-
-class MLPProblem(Problem):
- """Train a small MLP to predict |x|"""
-
- def _data(self, n_samples: int = 128, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
- def __init__(self, hidden_size=4, n_samples=128):
- self.model = nn.Sequential(
- nn.Linear(1, hidden_size),
- nn.ReLU(),
- nn.Linear(hidden_size, hidden_size),
- nn.ReLU(),
- nn.Linear(hidden_size, 1),
- )
- self.x_data, self.y_data = self._data(n_samples)
- self._dim = sum(p.numel() for p in self.model.parameters())
-
- @property
- def dim(self) -> int:
- return self._dim
-
- def init(self) -> np.ndarray:
- # Initialize model and return flattened parameters
- torch.manual_seed(42)
- for p in self.model.parameters():
- if len(p.shape) >= 2:
- nn.init.xavier_uniform_(p)
- else:
- nn.init.zeros_(p)
- return self._flatten_params()
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- # Instead of modifying model parameters, we'll use functional approach
- # Split x into weight and bias tensors
- idx = 0
- params = []
- for p in self.model.parameters():
- numel = p.numel()
- param = x[idx : idx + numel].view(p.shape)
- params.append(param)
- idx += numel
-
- # Manually compute forward pass using the parameters from x
- # For a 2-layer MLP: Linear -> ReLU -> Linear
- w1, b1, w2, b2, w3, b3 = params
-
- h = torch.nn.functional.linear(self.x_data, w1, b1)
- h = torch.nn.functional.relu(h)
- pred = torch.nn.functional.linear(h, w2, b2)
- h = torch.nn.functional.relu(h)
- pred = torch.nn.functional.linear(h, w3, b3)
-
- return nn.functional.mse_loss(pred, self.y_data)
-
- def _flatten_params(self) -> np.ndarray:
- """Flatten model parameters to vector"""
- return torch.cat([p.data.view(-1) for p in self.model.parameters()]).numpy()
-
- def _unflatten_params(self, x: torch.Tensor):
- """Set model parameters from flat vector (used for evaluation, not training)"""
- idx = 0
- for p in self.model.parameters():
- numel = p.numel()
- p.data.copy_(x[idx : idx + numel].view(p.shape))
- idx += numel
-
-
-class AbsMLP(MLPProblem):
- def _data(self, n_samples: int = 128, seed: int = 42):
- x = torch.rand((n_samples, 1)) * 4 - 2
- return x, x.abs()
-
-
-class SineMLP(MLPProblem):
- def _data(self, n_samples: int = 128, seed: int = 42):
- x = torch.rand((n_samples, 1)) * 4 - 2
- return x, x.sin()
-
-
-class ExpMLP(MLPProblem):
- def _data(self, n_samples: int = 128, seed: int = 42):
- x = torch.rand((n_samples, 1)) * 4 - 2
- return x, x.exp()
-
-
-class QuadraticBowl(Problem):
- """N-dimensional quadratic bowl"""
-
- def __init__(self, dim=10):
- self._dim = dim
- self.center = np.random.randn(dim) * 2
-
- @property
- def dim(self) -> int:
- return self._dim
-
- def init(self) -> np.ndarray:
- return np.random.randn(self._dim) * 3
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- center = torch.tensor(self.center, dtype=x.dtype, device=x.device)
- return torch.sum((x - center) ** 2)
-
- def optimum(self):
- return self.center
-
-
-class StyblinskiTang(Problem):
- """Styblinski-Tang function (separable, multimodal)"""
-
- def __init__(self, dim=4):
- self._dim = dim
-
- @property
- def dim(self) -> int:
- return self._dim
-
- def init(self) -> np.ndarray:
- return np.random.uniform(-5, 5, self._dim)
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- return 0.5 * torch.sum(x**4 - 16 * x**2 + 5 * x)
-
- def bounds(self):
- return [(-5, 5)] * self._dim
-
- def optimum(self):
- # Global minimum at x_i = -2.903534 for all i
- return np.full(self._dim, -2.903534)
-
-
-# Test functions
-PROBLEMS = { # 2D Problems
- "Simple Bowl (2D)": Function2D(
- func=lambda x: (x[0] - 1) ** 2 + (x[1] - 2) ** 2,
- bounds=[(-3, 5), (-2, 6)],
- init=[-1.0, 0.0],
- optimal=[1.0, 2.0],
- ),
- "Ravine (2D)": Function2D(
- func=lambda x: (x[0] - 1) ** 2 + (x[1] - 2) ** 100,
- bounds=[(-3, 5), (-2, 6)],
- init=[-1.0, 1.0],
- optimal=[1.0, 2.0],
- ),
- "Rosenbrock (2D)": Function2D(
- func=lambda x: (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2,
- bounds=[(-2, 2), (-1, 3)],
- init=[-1.0, 1.0],
- optimal=[1.0, 1.0],
- ),
- "Himmelblau (2D)": Function2D(
- func=lambda x: (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2,
- bounds=[(-5, 5), (-5, 5)],
- init=[0.0, 0.0],
- optimal=[3.0, 2.0],
- ),
- "Beale (2D)": Function2D(
- func=lambda x: (
- (1.5 - x[0] + x[0] * x[1]) ** 2
- + (2.25 - x[0] + x[0] * x[1] ** 2) ** 2
- + (2.625 - x[0] + x[0] * x[1] ** 3) ** 2
- ),
- bounds=[(-4.5, 4.5), (-4.5, 4.5)],
- init=[1.0, 1.0],
- optimal=[3.0, 0.5],
- ), # High-dimensional Problems
- "MLP abs(x) (13D)": AbsMLP(hidden_size=4, n_samples=128),
- "MLP sin(x) (13D)": SineMLP(hidden_size=4, n_samples=128),
- "MLP exp(x) (13D)": ExpMLP(hidden_size=4, n_samples=128),
- "Quadratic Bowl (10D)": QuadraticBowl(dim=10),
- "Styblinski-Tang (4D)": StyblinskiTang(dim=4),
- "Styblinski-Tang (8D)": StyblinskiTang(dim=8),
-}
-
-# Chainable optimizer components
-OPTIMIZER_BLOCKS = {
- "gradient": {
- "name": "🎯 Gradient Input",
- "color": COLORS["gradient"],
- "components": [
- {
- "id": "gradient_input",
- "name": "Raw Gradient",
- "icon": "∇",
- "description": "Start of the pipeline - raw gradients from backprop",
- "func": None, # This is just a placeholder
- "params": {},
- }
- ],
- },
- "momentum": {
- "name": "⚡ Momentum",
- "color": COLORS["momentum"],
- "components": [
- {
- "id": "heavyball",
- "name": "Heavy Ball",
- "icon": "🏐",
- "description": "Classic momentum: v = βv - g",
- "func": C.heavyball_momentum,
- "params": {"beta": 0.9},
- },
- {
- "id": "nesterov",
- "name": "Nesterov",
- "icon": "🚀",
- "description": "Look-ahead momentum",
- "func": C.nesterov_momentum,
- "params": {"beta": 0.9},
- },
- {
- "id": "nesterov_ema",
- "name": "Nesterov EMA",
- "icon": "📈",
- "description": "Exponential moving average variant",
- "func": C.nesterov_ema,
- "params": {"beta": 0.9},
- },
- {
- "id": "basic_ema",
- "name": "Basic EMA",
- "icon": "📉",
- "description": "Simple exponential moving average",
- "func": C.exp_avg,
- "params": {"betas": (0.9, 0.999)},
- },
- ],
- },
- "scaling": {
- "name": "📊 Adaptive Scaling",
- "color": COLORS["scaling"],
- "components": [
- {
- "id": "adam_scale",
- "name": "Adam Scaling",
- "icon": "👑",
- "description": "Adaptive per-parameter learning rates",
- "func": C.scale_by_adam,
- "params": {"betas": (0.9, 0.999), "eps": 1e-8},
- },
- {
- "id": "rmsprop_scale",
- "name": "RMSprop Scaling",
- "icon": "📉",
- "description": "Root mean square normalization",
- "func": C.scale_by_exp_avg_sq,
- "params": {"beta2": 0.999, "eps": 1e-8},
- },
- {
- "id": "adagrad_scale",
- "name": "AdaGrad Scaling",
- "icon": "📐",
- "description": "Accumulate all past gradients",
- "func": C.scale_by_exp_avg_sq,
- "params": {"beta2": 1.0, "eps": 1e-8},
- },
- ],
- },
- "regularization": {
- "name": "⚖️ Regularization",
- "color": COLORS["regularization"],
- "components": [
- {
- "id": "weight_decay",
- "name": "Weight Decay",
- "icon": "🪶",
- "description": "L2 regularization (AdamW style)",
- "func": C.weight_decay_to_ema,
- "params": {"weight_decay_to_ema": 0.01, "ema_beta": 0.999},
- },
- {
- "id": "weight_decay_init",
- "name": "Decay to Init",
- "icon": "🎯",
- "description": "Pull weights toward initialization",
- "func": C.weight_decay_to_init,
- "params": {"weight_decay_to_init": 0.01},
- },
- {
- "id": "l1_weight_decay",
- "name": "L1 Weight Decay",
- "icon": "⚡",
- "description": "L1 regularization to EMA",
- "func": C.l1_weight_decay_to_ema,
- "params": {"weight_decay_to_ema": 0.01, "ema_beta": 0.999},
- },
- ],
- },
- "normalization": {
- "name": "🔧 Gradient Processing",
- "color": COLORS["normalization"],
- "components": [
- {
- "id": "grad_clip",
- "name": "Gradient Clipping",
- "icon": "✂️",
- "description": "Clip gradient norm",
- "func": functools.partial(C.global_clip, clip_fn=heavyball.utils.l2_clip_),
- "params": {"max_norm": 1.0},
- },
- {
- "id": "sign",
- "name": "Sign SGD",
- "icon": "±",
- "description": "Use only gradient signs",
- "func": C.sign,
- "params": {"graft": True},
- },
- {
- "id": "orthogonalize",
- "name": "Orthogonalize",
- "icon": "⊥",
- "description": "Orthogonalize gradient to parameter",
- "func": C.orthogonalize_grad_to_param,
- "params": {"eps": 1e-8},
- },
- {
- "id": "orthogonalize_update",
- "name": "Orthogonalize Update",
- "icon": "⊗",
- "description": "Orthogonalize the update itself",
- "func": C.orthogonalize_update,
- "params": {},
- },
- ],
- },
- "advanced_scaling": {
- "name": "🚀 Advanced Optimizers",
- "color": COLORS["scaling"],
- "components": [
- {
- "id": "laprop",
- "name": "Laprop",
- "icon": "🌊",
- "description": "Layerwise adaptive propagation",
- "func": C.scale_by_laprop,
- "params": {"betas": (0.9, 0.999), "eps": 1e-8},
- },
- {
- "id": "adopt",
- "name": "ADOPT",
- "icon": "🎯",
- "description": "Adaptive gradient methods",
- "func": C.scale_by_adopt,
- "params": {"betas": (0.9, 0.9999), "eps": 1e-12},
- },
- ],
- },
- "preconditioning": {
- "name": "🔮 Preconditioning",
- "color": COLORS["gradient"],
- "components": [
- {
- "id": "soap",
- "name": "SOAP",
- "icon": "🧼",
- "description": "Shampoo-based preconditioning",
- "func": C.scale_by_soap,
- "params": {
- "shampoo_beta": 0.99,
- "max_precond_dim": 10000,
- "precondition_1d": False,
- "is_preconditioning": True,
- "betas": (0.9, 0.999),
- "eps": 1e-8,
- },
- },
- {
- "id": "psgd",
- "name": "PSGD",
- "icon": "🎲",
- "description": "Preconditioned SGD",
- "func": functools.partial(C.scale_by_psgd, cached=False),
- "params": {
- "precond_lr": 0.1,
- "max_size_triangular": 1024,
- "precondition_frequency": 10,
- "adaptive": False,
- "store_triu_as_line": True,
- "q_dtype": "float32",
- "inverse_free": False,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "min_ndim_triangular": 2,
- "memory_save_mode": None,
- "dampening": 1.0,
- "is_preconditioning": True,
- "betas": (0.9, 0.999),
- "eps": 1e-8,
- "ortho_method": "qr",
- "lower_bound_beta": 0.999,
- "precond_update_power_iterations": 1,
- },
- },
- {
- "id": "psgd_lra",
- "name": "PSGD LRA",
- "icon": "📐",
- "description": "Low-rank approximation PSGD",
- "func": C.scale_by_psgd_lra,
- "params": {
- "precond_lr": 0.1,
- "rank": 4,
- "param_count": 10000,
- "precondition_frequency": 10,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "q_dtype": "float32",
- "is_preconditioning": True,
- "eps": 1e-8,
- "betas": (0.9, 0.999),
- },
- },
- {
- "id": "delayed_psgd",
- "name": "Delayed PSGD",
- "icon": "⏱️",
- "description": "PSGD with delayed preconditioner updates",
- "func": functools.partial(C.scale_by_delayed_psgd, cached=False),
- "params": {
- "precond_lr": 0.1,
- "max_size_triangular": 1024,
- "precondition_frequency": 10,
- "adaptive": False,
- "store_triu_as_line": True,
- "q_dtype": "float32",
- "inverse_free": False,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "min_ndim_triangular": 2,
- "memory_save_mode": None,
- "dampening": 1.0,
- "is_preconditioning": True,
- "betas": (0.9, 0.999),
- "eps": 1e-8,
- "ortho_method": "qr",
- "lower_bound_beta": 0.999,
- "precond_update_power_iterations": 1,
- },
- },
- {
- "id": "delayed_psgd_lra",
- "name": "Delayed PSGD LRA",
- "icon": "⏰",
- "description": "Delayed low-rank PSGD",
- "func": C.scale_by_delayed_psgd_lra,
- "params": {
- "precond_lr": 0.1,
- "rank": 4,
- "param_count": 10000,
- "precondition_frequency": 10,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "q_dtype": "float32",
- "is_preconditioning": True,
- "eps": 1e-8,
- "betas": (0.9, 0.999),
- },
- },
- ],
- },
- "adaptive_lr": {
- "name": "🎛️ Adaptive Learning Rate",
- "color": COLORS["regularization"],
- "components": [
- {
- "id": "d_adapt",
- "name": "D-Adaptation",
- "icon": "📈",
- "description": "Automatic learning rate adaptation",
- "func": C.scale_by_d_adaptation,
- "params": {"initial_d": 1.0},
- },
- {
- "id": "lr_adapt",
- "name": "LR Adaptation",
- "icon": "🔄",
- "description": "Learning rate adaptation",
- "func": C.scale_by_lr_adaptation,
- "params": {"initial_d": 1.0, "lr_lr": 0.1},
- },
- {
- "id": "pointwise_lr_adapt",
- "name": "Pointwise LR",
- "icon": "🎚️",
- "description": "Per-parameter learning rate",
- "func": C.scale_by_pointwise_lr_adaptation,
- "params": {"initial_d": 1.0, "lr_lr": 0.1},
- },
- ],
- },
- "special": {
- "name": "✨ Special Methods",
- "color": COLORS["update"],
- "components": [
- {
- "id": "schedule_free",
- "name": "Schedule-Free",
- "icon": "🗓️",
- "description": "No learning rate schedule needed",
- "func": C.update_by_schedule_free,
- "params": {"r": 0.0, "weight_lr_power": 2.0},
- },
- {
- "id": "msam",
- "name": "MSAM",
- "icon": "🏔️",
- "description": "Momentum SAM optimizer",
- "func": C.update_by_msam,
- "params": {"sam_step_size": 0.05},
- },
- {
- "id": "mup",
- "name": "μP Approx",
- "icon": "📏",
- "description": "Maximal update parametrization",
- "func": C.mup_approx,
- "params": {},
- },
- {
- "id": "palm_beta2",
- "name": "PALM β₂",
- "icon": "🌴",
- "description": "Dynamic β₂ scheduling for PALM",
- "func": C.palm_beta2,
- "params": {"beta2_scale": 0.8},
- },
- {
- "id": "identity",
- "name": "Identity",
- "icon": "🔄",
- "description": "Pass-through (no operation)",
- "func": C.identity,
- "params": {},
- },
- ],
- },
-}
-
-# Pre-built optimizer recipes
-RECIPES = {
- "SGD": ["gradient_input"],
- "SGD + Momentum": ["gradient_input", "heavyball"],
- "Adam": ["gradient_input", "adam_scale"],
- "AdamW": ["gradient_input", "adam_scale", "weight_decay"],
- "RMSprop": ["gradient_input", "rmsprop_scale"],
- "Nesterov SGD": ["gradient_input", "nesterov"],
- "SOAP": ["gradient_input", "soap"],
- "Laprop": ["gradient_input", "laprop"],
- "ADOPT": ["gradient_input", "adopt"],
- "Sign SGD": ["gradient_input", "sign"],
- "AdamW + Clipping": ["gradient_input", "grad_clip", "adam_scale", "weight_decay"],
- "D-Adapted Adam": ["gradient_input", "adam_scale", "d_adapt"],
- "PSGD": ["gradient_input", "psgd"],
- "Schedule-Free AdamW": ["gradient_input", "adam_scale", "weight_decay", "schedule_free"],
- "EMA SGD": ["gradient_input", "basic_ema"],
- "Orthogonal Adam": ["gradient_input", "orthogonalize_update", "adam_scale"],
- "PALM": ["gradient_input", "palm_beta2", "adam_scale"],
- "Delayed PSGD": ["gradient_input", "delayed_psgd"],
- "μP Adam": ["gradient_input", "mup", "adam_scale"],
-}
-
-
-def get_component_info(comp_id):
- """Get component info by ID"""
- for category in OPTIMIZER_BLOCKS.values():
- for comp in category["components"]:
- if comp["id"] == comp_id:
- return comp, category["color"]
- return None, "#757575"
-
-
-def create_pipeline_display(pipeline):
- """Create HTML display for the current pipeline"""
- if not pipeline or pipeline == ["gradient_input"]:
- return """
-
- Drop components here to build your optimizer pipeline
-
- """
-
- blocks_html = ""
- for i, comp_id in enumerate(pipeline):
- comp_info, color = get_component_info(comp_id)
- if comp_info:
- show_arrow = i < len(pipeline) - 1
- blocks_html += f"""
-
-
-
{comp_info["icon"]}
-
{comp_info["name"]}
-
-
- {'
→
' if show_arrow else ""}
-
- """
-
- return f"""
-
- {blocks_html}
-
- """
-
-
-def build_optimizer_from_pipeline(pipeline: List[str], params, kwargs):
- """Build optimizer from pipeline of component IDs"""
- fns = []
- opt_params = {
- "lr": 0.001,
- "step": 1, # Required for many functions
- "caution": False,
- "weight_decay": 0.0,
- **kwargs,
- }
-
- for comp_id in pipeline:
- if comp_id == "gradient_input":
- continue # Skip the input block
-
- # Find component
- comp_info, _ = get_component_info(comp_id)
- if comp_info and comp_info["func"] is not None:
- fns.append(comp_info["func"])
- # Update parameters, handling special cases
- params_to_add = comp_info["params"].copy()
-
- # Handle special parameter mappings
- if "beta" in params_to_add and "betas" not in params_to_add:
- # Convert single beta to betas tuple for functions expecting it
- beta = params_to_add.pop("beta")
- if "betas" in opt_params:
- opt_params["betas"] = (beta, opt_params["betas"][1])
- else:
- opt_params["betas"] = (beta, 0.999)
-
- # First add component defaults, then override with user-provided kwargs
- for key, value in params_to_add.items():
- if key not in kwargs: # Only use default if not provided by user
- opt_params[key] = value
-
- if not fns:
- # Default to simple gradient descent
- return C.BaseOpt(params, opt_params, fns=[C.identity])
-
- return C.BaseOpt(params, opt_params, fns=fns)
-
-
-def run_optimization(problem_name: str, pipeline: List[str], steps: int, **kwargs) -> Tuple[Any, Dict]:
- """Run optimization with the custom pipeline"""
- problem = PROBLEMS[problem_name]
- kwargs["lr"] = 10 ** kwargs.get("lr", -2)
-
- # Initialize parameters
- init = problem.init()
- x = torch.nn.Parameter(torch.tensor(init, dtype=torch.float32))
- params = [x]
-
- # Build optimizer from pipeline
- optimizer = build_optimizer_from_pipeline(pipeline, params, kwargs)
-
- # Run optimization
- trajectory = []
- losses = []
- gradients = []
-
- for i in range(steps):
- trajectory.append(x.detach().numpy().copy())
-
- def closure():
- optimizer.zero_grad()
- loss = problem.loss(x)
- loss.backward()
- if x.grad is not None:
- gradients.append(x.grad.detach().numpy().copy())
- return loss
-
- loss = optimizer.step(closure)
- optimizer.zero_grad()
- losses.append(loss.item())
-
- trajectory = np.array(trajectory)
-
- # Create visualization
- fig = create_visualization(problem_name, trajectory, losses, gradients, pipeline)
-
- return fig, {
- "trajectory": trajectory.tolist(),
- "losses": losses,
- "final_loss": losses[-1],
- "steps_to_converge": find_convergence(losses),
- }
-
-
-def find_convergence(losses, threshold=1e-6):
- """Find when optimization converged"""
- if len(losses) < 10:
- return len(losses)
-
- for i in range(10, len(losses)):
- if abs(losses[i] - losses[i - 5]) < threshold:
- return i
- return len(losses)
-
-
-def create_visualization(problem_name, trajectory, losses, gradients, pipeline):
- """Create integrated visualization"""
- problem = PROBLEMS[problem_name]
-
- if problem.dim == 2:
- return create_2d_visualization(problem_name, trajectory, losses, gradients, pipeline)
- return create_highdim_visualization(problem_name, trajectory, losses, gradients, pipeline)
-
-
-def create_2d_visualization(problem_name, trajectory, losses, gradients, pipeline):
- """Create visualization for 2D problems"""
- problem = PROBLEMS[problem_name]
- bounds = problem.bounds()
-
- # Create subplots
- fig = make_subplots(
- rows=2,
- cols=2,
- subplot_titles=("Optimization Landscape", "Pipeline Architecture", "Loss Curve", "Learning Dynamics"),
- column_widths=[0.6, 0.4],
- row_heights=[0.6, 0.4],
- specs=[[{"type": "contour"}, {"type": "scatter"}], [{"type": "scatter"}, {"type": "scatter"}]],
- )
-
- # 1. Optimization landscape
- x = np.linspace(bounds[0][0], bounds[0][1], 100)
- y = np.linspace(bounds[1][0], bounds[1][1], 100)
- X, Y = np.meshgrid(x, y)
- Z = np.zeros_like(X)
-
- for i in range(X.shape[0]):
- for j in range(X.shape[1]):
- point = torch.tensor([X[i, j], Y[i, j]], dtype=torch.float32)
- Z[i, j] = problem.loss(point).item()
-
- # Contour plot
- fig.add_trace(
- go.Contour(
- x=x,
- y=y,
- z=Z,
- colorscale=[[0, "#e3f2fd"], [0.5, "#2196f3"], [1, "#0d47a1"]],
- showscale=False,
- contours=dict(
- start=0,
- end=Z.max(),
- size=Z.max() / 15,
- ),
- ),
- row=1,
- col=1,
- )
-
- # Add optimization path
- if len(trajectory) > 0:
- colors = np.linspace(0, 1, len(trajectory))
-
- for i in range(1, len(trajectory)):
- fig.add_trace(
- go.Scatter(
- x=[trajectory[i - 1, 0], trajectory[i, 0]],
- y=[trajectory[i - 1, 1], trajectory[i, 1]],
- mode="lines",
- line=dict(color=f"rgba(255, {int(111 * (1 - colors[i]))}, 0, {0.3 + 0.7 * colors[i]})", width=3),
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=1,
- )
-
- # Start and end points
- fig.add_trace(
- go.Scatter(
- x=[trajectory[0, 0]],
- y=[trajectory[0, 1]],
- mode="markers+text",
- marker=dict(size=12, color="#4caf50", line=dict(color="white", width=2)),
- text=["Start"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- fig.add_trace(
- go.Scatter(
- x=[trajectory[-1, 0]],
- y=[trajectory[-1, 1]],
- mode="markers+text",
- marker=dict(size=12, color="#ff6f00", line=dict(color="white", width=2)),
- text=["End"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- # 2. Pipeline visualization
- block_x = []
- block_y = []
- block_text = []
- block_colors = []
-
- for i, comp_id in enumerate(pipeline):
- block_x.append(i / (len(pipeline) - 1) if len(pipeline) > 1 else 0.5)
- block_y.append(0.5)
-
- comp_info, comp_color = get_component_info(comp_id)
- block_text.append(comp_info["icon"] if comp_info else "?")
- block_colors.append(comp_color)
-
- fig.add_trace(
- go.Scatter(
- x=block_x,
- y=block_y,
- mode="markers+text",
- marker=dict(size=40, color=block_colors, line=dict(color="white", width=2)),
- text=block_text,
- textposition="middle center",
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=2,
- )
-
- # Add arrows between blocks
- for i in range(len(pipeline) - 1):
- fig.add_annotation(
- x=block_x[i + 1],
- y=0.5,
- ax=block_x[i],
- ay=0.5,
- xref="x2",
- yref="y2",
- axref="x2",
- ayref="y2",
- showarrow=True,
- arrowhead=2,
- arrowsize=1,
- arrowwidth=2,
- arrowcolor="#ff6f00",
- row=1,
- col=2,
- )
-
- # 3. Loss curve
- fig.add_trace(
- go.Scatter(
- x=list(range(len(losses))),
- y=losses,
- mode="lines",
- line=dict(color="#ff6f00", width=3),
- fill="tozeroy",
- fillcolor="rgba(255, 111, 0, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=1,
- )
-
- # 4. Learning dynamics (gradient norm)
- if gradients:
- grad_norms = [np.linalg.norm(g) for g in gradients]
- fig.add_trace(
- go.Scatter(
- x=list(range(len(grad_norms))),
- y=grad_norms,
- mode="lines",
- line=dict(color="#2196f3", width=2),
- fill="tozeroy",
- fillcolor="rgba(33, 150, 243, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=2,
- )
-
- # Update layout
- fig.update_layout(
- height=700,
- showlegend=False,
- paper_bgcolor="white",
- plot_bgcolor="white",
- margin=dict(l=40, r=40, t=60, b=40),
- )
-
- # Update axes
- fig.update_xaxes(showgrid=True, gridcolor="#f0f0f0")
- fig.update_yaxes(showgrid=True, gridcolor="#f0f0f0")
-
- # Pipeline plot
- fig.update_xaxes(showticklabels=False, showgrid=False, row=1, col=2)
- fig.update_yaxes(showticklabels=False, showgrid=False, range=[0, 1], row=1, col=2)
-
- # Loss plot
- fig.update_xaxes(title_text="Iteration", row=2, col=1)
- fig.update_yaxes(title_text="Loss", type="log", row=2, col=1)
-
- # Gradient plot
- fig.update_xaxes(title_text="Iteration", row=2, col=2)
- fig.update_yaxes(title_text="Gradient Norm", row=2, col=2)
-
- return fig
-
-
-def create_highdim_visualization(problem_name, trajectory, losses, gradients, pipeline):
- """Create visualization for high-dimensional problems using PCA"""
- problem = PROBLEMS[problem_name]
-
- # Create subplots
- fig = make_subplots(
- rows=2,
- cols=2,
- subplot_titles=("PCA Trajectory Projection", "Pipeline Architecture", "Loss Curve", "Learning Dynamics"),
- column_widths=[0.6, 0.4],
- row_heights=[0.6, 0.4],
- specs=[[{"type": "contour"}, {"type": "scatter"}], [{"type": "scatter"}, {"type": "scatter"}]],
- )
-
- # 1. PCA projection of trajectory
- if len(trajectory) > 2:
- # Fit PCA on trajectory
- trajectory_array = np.array(trajectory)
- pca = PCA(n_components=min(2, trajectory_array.shape[1]))
- trajectory_2d = pca.fit_transform(trajectory_array)
-
- # Create loss landscape in PCA space
- # Determine bounds based on trajectory
- margin = 0.2
- x_min, x_max = trajectory_2d[:, 0].min(), trajectory_2d[:, 0].max()
- x_range = x_max - x_min
- x_min -= margin * x_range
- x_max += margin * x_range
-
- if trajectory_2d.shape[1] > 1:
- y_min, y_max = trajectory_2d[:, 1].min(), trajectory_2d[:, 1].max()
- y_range = y_max - y_min
- y_min -= margin * y_range
- y_max += margin * y_range
- else:
- y_min, y_max = -1, 1
-
- # Create grid in PCA space
- n_points = 50
- x_grid = np.linspace(x_min, x_max, n_points)
- y_grid = np.linspace(y_min, y_max, n_points)
- X_grid, Y_grid = np.meshgrid(x_grid, y_grid)
-
- # Compute losses on grid
- Z_grid = np.zeros_like(X_grid)
- mean_trajectory = pca.mean_
-
- for i in range(n_points):
- for j in range(n_points):
- # Map 2D point back to high-dimensional space
- pca_coords = np.array([X_grid[i, j], Y_grid[i, j]])
- if trajectory_2d.shape[1] == 1:
- pca_coords = pca_coords[:1] # Use only first coordinate
-
- # Reconstruct high-dimensional point
- high_dim_point = mean_trajectory + pca_coords @ pca.components_[: len(pca_coords)]
-
- # Evaluate loss
- x_tensor = torch.tensor(high_dim_point, dtype=torch.float32)
- Z_grid[i, j] = problem.loss(x_tensor).item()
-
- # Add contour plot
- fig.add_trace(
- go.Contour(
- x=x_grid,
- y=y_grid,
- z=Z_grid,
- colorscale=[[0, "#e3f2fd"], [0.5, "#2196f3"], [1, "#0d47a1"]],
- showscale=False,
- contours=dict(
- start=Z_grid.min(),
- end=Z_grid.max(),
- size=(Z_grid.max() - Z_grid.min()) / 15,
- ),
- ),
- row=1,
- col=1,
- )
-
- # Add trajectory on top of contour
- colors = np.linspace(0, 1, len(trajectory_2d))
-
- for i in range(1, len(trajectory_2d)):
- fig.add_trace(
- go.Scatter(
- x=[trajectory_2d[i - 1, 0], trajectory_2d[i, 0]],
- y=[trajectory_2d[i - 1, 1] if trajectory_2d.shape[1] > 1 else [0, 0]],
- mode="lines",
- line=dict(color=f"rgba(255, {int(111 * (1 - colors[i]))}, 0, {0.3 + 0.7 * colors[i]})", width=3),
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=1,
- )
-
- # Start and end points
- fig.add_trace(
- go.Scatter(
- x=[trajectory_2d[0, 0]],
- y=[trajectory_2d[0, 1] if trajectory_2d.shape[1] > 1 else 0],
- mode="markers+text",
- marker=dict(size=12, color="#4caf50", line=dict(color="white", width=2)),
- text=["Start"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- fig.add_trace(
- go.Scatter(
- x=[trajectory_2d[-1, 0]],
- y=[trajectory_2d[-1, 1] if trajectory_2d.shape[1] > 1 else 0],
- mode="markers+text",
- marker=dict(size=12, color="#ff6f00", line=dict(color="white", width=2)),
- text=["End"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- # Add explained variance text
- if trajectory_2d.shape[1] > 1:
- explained_var = pca.explained_variance_ratio_
- fig.add_annotation(
- x=0.5,
- y=1.1,
- xref="x domain",
- yref="y domain",
- text=f"Explained variance: PC1={explained_var[0]:.1%}, PC2={explained_var[1]:.1%}",
- showarrow=False,
- row=1,
- col=1,
- )
-
- # 2. Pipeline visualization (same as 2D)
- block_x = []
- block_y = []
- block_text = []
- block_colors = []
-
- for i, comp_id in enumerate(pipeline):
- block_x.append(i / (len(pipeline) - 1) if len(pipeline) > 1 else 0.5)
- block_y.append(0.5)
-
- comp_info, comp_color = get_component_info(comp_id)
- block_text.append(comp_info["icon"] if comp_info else "?")
- block_colors.append(comp_color)
-
- fig.add_trace(
- go.Scatter(
- x=block_x,
- y=block_y,
- mode="markers+text",
- marker=dict(size=40, color=block_colors, line=dict(color="white", width=2)),
- text=block_text,
- textposition="middle center",
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=2,
- )
-
- # Add arrows between blocks
- for i in range(len(pipeline) - 1):
- fig.add_annotation(
- x=block_x[i + 1],
- y=0.5,
- ax=block_x[i],
- ay=0.5,
- xref="x2",
- yref="y2",
- axref="x2",
- ayref="y2",
- showarrow=True,
- arrowhead=2,
- arrowsize=1,
- arrowwidth=2,
- arrowcolor="#ff6f00",
- row=1,
- col=2,
- )
-
- # 3. Loss curve
- fig.add_trace(
- go.Scatter(
- x=list(range(len(losses))),
- y=losses,
- mode="lines",
- line=dict(color="#ff6f00", width=3),
- fill="tozeroy",
- fillcolor="rgba(255, 111, 0, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=1,
- )
-
- # 4. Learning dynamics (gradient norm)
- if gradients:
- grad_norms = [np.linalg.norm(g) for g in gradients]
- fig.add_trace(
- go.Scatter(
- x=list(range(len(grad_norms))),
- y=grad_norms,
- mode="lines",
- line=dict(color="#2196f3", width=2),
- fill="tozeroy",
- fillcolor="rgba(33, 150, 243, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=2,
- )
-
- # Update layout
- fig.update_layout(
- height=700,
- showlegend=False,
- paper_bgcolor="white",
- plot_bgcolor="white",
- margin=dict(l=40, r=40, t=60, b=40),
- )
-
- # Update axes
- fig.update_xaxes(showgrid=True, gridcolor="#f0f0f0")
- fig.update_yaxes(showgrid=True, gridcolor="#f0f0f0")
-
- # PCA plot
- fig.update_xaxes(title_text="First Principal Component", row=1, col=1)
- fig.update_yaxes(title_text="Second Principal Component", row=1, col=1)
-
- # Pipeline plot
- fig.update_xaxes(showticklabels=False, showgrid=False, row=1, col=2)
- fig.update_yaxes(showticklabels=False, showgrid=False, range=[0, 1], row=1, col=2)
-
- # Loss plot
- fig.update_xaxes(title_text="Iteration", row=2, col=1)
- fig.update_yaxes(title_text="Loss", type="log", row=2, col=1)
-
- # Gradient plot
- fig.update_xaxes(title_text="Iteration", row=2, col=2)
- fig.update_yaxes(title_text="Gradient Norm", row=2, col=2)
-
- return fig
-
-
-def create_app():
- with gr.Blocks(theme=gr.themes.Base()) as app:
- # Custom CSS
- gr.HTML("""
-
- """)
-
- # Header
- gr.Markdown("""
- # 🧩 HeavyBall Chainable Optimizer Playground
-
- ### Build custom optimizers by combining components like LEGO blocks!
-
- Click components to add them to your pipeline. Each component transforms the gradient in a specific way - stack them to create powerful optimization algorithms!
- """)
-
- # Hidden state for pipeline
- pipeline_state = gr.State(["gradient_input"])
-
- with gr.Row():
- # Left column - Component palette
- with gr.Column(scale=1):
- gr.Markdown("### 🎨 Component Palette")
- gr.Markdown("*Click blocks to add to pipeline*")
-
- # Component buttons organized by category
- for cat_id, category in OPTIMIZER_BLOCKS.items():
- if cat_id == "gradient": # Skip gradient input in palette
- continue
-
- gr.HTML(f'')
-
- with gr.Row():
- for comp in category["components"]:
- btn = gr.Button(
- value=f"{comp['icon']} {comp['name']}", elem_id=f"btn_{comp['id']}", size="sm"
- )
- # Store component ID in button for click handler
- btn.click(
- fn=lambda p, cid=comp["id"]: p + [cid],
- inputs=[pipeline_state],
- outputs=[pipeline_state],
- )
-
- # Recipe selector
- gr.Markdown("### 📚 Preset Recipes")
- recipe_dropdown = gr.Dropdown(choices=list(RECIPES.keys()), value=None, label="Load a preset optimizer")
- load_recipe_btn = gr.Button("Load Recipe", size="sm")
-
- # Center column - Main visualization
- with gr.Column(scale=2):
- # Pipeline builder
- gr.Markdown("### 🔧 Pipeline Builder")
- pipeline_display = gr.HTML()
-
- with gr.Row():
- clear_pipeline_btn = gr.Button("🗑️ Clear Pipeline", size="sm", variant="secondary")
- refresh_btn = gr.Button("🔄 Refresh Display", size="sm")
-
- # Visualization
- viz_plot = gr.Plot(label="")
-
- # Run button
- run_btn = gr.Button("🚀 Run Optimization", variant="primary", size="lg")
-
- # Right column - Parameters and metrics
- with gr.Column(scale=1):
- gr.Markdown("### ⚙️ Parameters")
-
- problem_select = gr.Dropdown(choices=list(PROBLEMS.keys()), value="Rosenbrock (2D)", label="Problem")
-
- lr_slider = gr.Slider(minimum=-4, maximum=-1, value=-2, step=0.1, label="Learning Rate (10^x)")
-
- steps_slider = gr.Slider(minimum=10, maximum=500, value=200, step=10, label="Steps")
-
- # Component-specific parameters
- with gr.Accordion("Advanced Parameters", open=False):
- beta_slider = gr.Slider(minimum=0.0, maximum=0.999, value=0.9, step=0.001, label="Momentum β")
- beta2_slider = gr.Slider(minimum=0.0, maximum=0.999, value=0.999, step=0.001, label="Adam β₂")
- eps_slider = gr.Slider(minimum=1e-8, maximum=1e-4, value=1e-8, step=1e-8, label="Epsilon")
-
- # Weight decay parameters
- weight_decay_slider = gr.Slider(
- minimum=0.0, maximum=0.1, value=0.01, step=0.001, label="Weight Decay"
- )
- ema_beta_slider = gr.Slider(
- minimum=0.9, maximum=0.999, value=0.999, step=0.001, label="EMA β (for weight decay)"
- )
-
- # Gradient clipping
- max_norm_slider = gr.Slider(
- minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Gradient Clip Norm"
- )
-
- # Preconditioning parameters
- shampoo_beta_slider = gr.Slider(
- minimum=0.9, maximum=0.999, value=0.99, step=0.001, label="Shampoo β"
- )
- precond_lr_slider = gr.Slider(
- minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Preconditioner LR"
- )
- precondition_frequency_slider = gr.Slider(
- minimum=1, maximum=100, value=10, step=1, label="Precondition Frequency"
- )
- rank_slider = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Low-rank Approximation Rank")
- param_count_slider = gr.Slider(
- minimum=1000, maximum=100000, value=10000, step=1000, label="Param Count (PSGD LRA)"
- )
-
- # Adaptive LR parameters
- initial_d_slider = gr.Slider(
- minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Initial D (D-adaptation)"
- )
- lr_lr_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Learning Rate LR")
-
- # Special method parameters
- r_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Schedule-Free r")
- weight_lr_power_slider = gr.Slider(
- minimum=0.0, maximum=4.0, value=2.0, step=0.1, label="Weight LR Power"
- )
- sam_step_size_slider = gr.Slider(
- minimum=0.01, maximum=0.5, value=0.05, step=0.01, label="SAM Step Size"
- )
- beta2_scale_slider = gr.Slider(
- minimum=0.5, maximum=1.0, value=0.8, step=0.01, label="PALM β₂ Scale"
- )
-
- # Sign SGD parameters
- graft_checkbox = gr.Checkbox(value=True, label="Graft (Sign SGD)")
-
- gr.Markdown("### 📊 Metrics")
-
- final_loss_display = gr.Textbox(label="Final Loss", value="-")
- convergence_display = gr.Textbox(label="Steps to Converge", value="-")
-
- # Footer
- gr.Markdown("""
- ---
- ### 💡 How it works:
-
- 1. **Start with Gradient** - Every pipeline begins with raw gradients
- 2. **Add Transforms** - Click components to add them to your pipeline
- 3. **Order Matters** - Components are applied in sequence
- 4. **Run & Compare** - See how different combinations perform!
-
- **Example combinations:**
- - Adam = Gradient → Adam Scaling
- - AdamW = Gradient → Adam Scaling → Weight Decay
- - Momentum SGD = Gradient → Heavy Ball
- """)
-
- # Event handlers
- def update_display(pipeline):
- return create_pipeline_display(pipeline)
-
- def load_recipe(recipe_name):
- if recipe_name in RECIPES:
- return RECIPES[recipe_name].copy()
- return ["gradient_input"]
-
- def clear_pipeline():
- return ["gradient_input"]
-
- def remove_block(pipeline, index):
- """Remove a block from the pipeline"""
- if 0 <= index < len(pipeline) and pipeline[index] != "gradient_input":
- new_pipeline = pipeline.copy()
- new_pipeline.pop(index)
- return new_pipeline
- return pipeline
-
- def run_optimization_handler(
- problem,
- pipeline,
- steps,
- lr,
- beta,
- beta2,
- eps,
- weight_decay,
- ema_beta,
- max_norm,
- shampoo_beta,
- precond_lr,
- precondition_frequency,
- rank,
- param_count,
- initial_d,
- lr_lr,
- r,
- weight_lr_power,
- sam_step_size,
- beta2_scale,
- graft,
- ):
- """Run optimization with current pipeline"""
- if not pipeline or len(pipeline) == 1:
- pipeline = ["gradient_input"]
-
- # Run optimization with all parameters
- fig, metrics = run_optimization(
- problem,
- pipeline,
- steps,
- lr=lr,
- beta=beta,
- betas=(beta, beta2),
- beta2=beta2,
- eps=eps,
- weight_decay_to_ema=weight_decay,
- weight_decay_to_init=weight_decay,
- ema_beta=ema_beta,
- max_norm=max_norm,
- shampoo_beta=shampoo_beta,
- precond_lr=precond_lr,
- precondition_frequency=precondition_frequency,
- rank=rank,
- param_count=param_count,
- initial_d=initial_d,
- lr_lr=lr_lr,
- r=r,
- weight_lr_power=weight_lr_power,
- sam_step_size=sam_step_size,
- beta2_scale=beta2_scale,
- graft=graft,
- )
-
- return fig, f"{metrics['final_loss']:.2e}", str(metrics["steps_to_converge"])
-
- # Connect events
- # Update display when pipeline changes
- pipeline_state.change(fn=update_display, inputs=[pipeline_state], outputs=[pipeline_display])
-
- # Recipe loading
- load_recipe_btn.click(fn=load_recipe, inputs=[recipe_dropdown], outputs=[pipeline_state])
-
- # Clear pipeline
- clear_pipeline_btn.click(fn=clear_pipeline, outputs=[pipeline_state])
-
- # Refresh display
- refresh_btn.click(fn=update_display, inputs=[pipeline_state], outputs=[pipeline_display])
-
- # Run optimization
- run_btn.click(
- fn=run_optimization_handler,
- inputs=[
- problem_select,
- pipeline_state,
- steps_slider,
- lr_slider,
- beta_slider,
- beta2_slider,
- eps_slider,
- weight_decay_slider,
- ema_beta_slider,
- max_norm_slider,
- shampoo_beta_slider,
- precond_lr_slider,
- precondition_frequency_slider,
- rank_slider,
- param_count_slider,
- initial_d_slider,
- lr_lr_slider,
- r_slider,
- weight_lr_power_slider,
- sam_step_size_slider,
- beta2_scale_slider,
- graft_checkbox,
- ],
- outputs=[viz_plot, final_loss_display, convergence_display],
- )
-
- # Add JavaScript for removing blocks
- app.load(
- None,
- None,
- None,
- js="""
- function() {
- // Function to remove pipeline blocks
- window.removePipelineBlock = function(index) {
- // This would need to trigger a Gradio event
- console.log('Remove block at index:', index);
- // In a real implementation, this would update the pipeline state
- };
-
- console.log('Playground initialized');
- }
- """,
- )
-
- # Initialize display
- app.load(
- fn=lambda: (
- create_pipeline_display(["gradient_input"]),
- run_optimization("Rosenbrock (2D)", ["gradient_input", "adam_scale"], 200, lr=-2)[0],
- ),
- outputs=[pipeline_display, viz_plot],
- )
-
- return app
-
-
-if __name__ == "__main__":
- app = create_app()
- app.launch(share=False)
diff --git a/interactive/static/init-globals.js b/interactive/static/init-globals.js
deleted file mode 100644
index b820bc9..0000000
--- a/interactive/static/init-globals.js
+++ /dev/null
@@ -1,2 +0,0 @@
-// This file initializes global variables for the node editor
-// It will be populated by Python when the page loads
diff --git a/interactive/static/node-editor.js b/interactive/static/node-editor.js
deleted file mode 100644
index 525597f..0000000
--- a/interactive/static/node-editor.js
+++ /dev/null
@@ -1,570 +0,0 @@
-console.log('Node editor script starting execution...');
-
-class NodeEditor {
- constructor(containerId) {
- this.container = document.getElementById(containerId);
- this.canvas = document.getElementById('canvas');
- this.palette = document.getElementById('palette');
- this.inspector = document.getElementById('inspector');
- this.nodes = [];
- this.connections = [];
- this.selectedNode = null;
- this.draggingNode = null;
- this.connecting = null;
- this.nodeIdCounter = 0;
- this.offset = { x: 0, y: 0 };
-
- this.init();
- }
-
- init() {
- // Set up SVG for connections
- this.svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg');
- this.svg.style.position = 'absolute';
- this.svg.style.width = '100%';
- this.svg.style.height = '100%';
- this.svg.style.pointerEvents = 'none';
- this.svg.style.zIndex = '1';
- this.svg.style.top = '0';
- this.svg.style.left = '0';
- this.canvas.appendChild(this.svg);
-
- // Set up event listeners
- this.setupPalette();
- this.setupCanvas();
-
- // Add gradient input node by default
- this.addNode('gradient_input', 100, 300);
- }
-
- setupPalette() {
- // Make palette items draggable
- const items = this.palette.querySelectorAll('.palette-item');
- items.forEach(item => {
- item.draggable = true;
- item.addEventListener('dragstart', (e) => this.onPaletteDragStart(e));
- item.addEventListener('dragend', (e) => this.onPaletteDragEnd(e));
- });
- }
-
- setupCanvas() {
- this.canvas.addEventListener('dragover', (e) => e.preventDefault());
- this.canvas.addEventListener('drop', (e) => this.onCanvasDrop(e));
- this.canvas.addEventListener('mousedown', (e) => this.onCanvasMouseDown(e));
- this.canvas.addEventListener('mousemove', (e) => this.onCanvasMouseMove(e));
- this.canvas.addEventListener('mouseup', (e) => this.onCanvasMouseUp(e));
- }
-
- onPaletteDragStart(e) {
- e.target.classList.add('dragging');
- e.dataTransfer.effectAllowed = 'copy';
- e.dataTransfer.setData('nodeType', e.target.dataset.type);
- }
-
- onPaletteDragEnd(e) {
- e.target.classList.remove('dragging');
- }
-
- onCanvasDrop(e) {
- e.preventDefault();
- const nodeType = e.dataTransfer.getData('nodeType');
- if (nodeType) {
- const rect = this.canvas.getBoundingClientRect();
- const x = e.clientX - rect.left - 80;
- const y = e.clientY - rect.top - 30;
- this.addNode(nodeType, x, y);
- this.updateCodePreview();
- }
- }
-
- addNode(type, x, y) {
- console.log(`Adding node: type=${type}, x=${x}, y=${y}`);
- const nodeId = `node-${this.nodeIdCounter++}`;
- const nodeData = window.CHAINABLE_FUNCTIONS[type] || {
- description: type === 'gradient_input' ? 'Gradient Input' : 'Unknown',
- color: '#666',
- inputs: type === 'gradient_input' ? [] : ['grad'],
- outputs: type === 'gradient_input' ? ['grad'] : []
- };
-
- const node = document.createElement('div');
- node.className = 'node';
- node.id = nodeId;
- node.style.left = x + 'px';
- node.style.top = y + 'px';
- node.style.setProperty('--node-color', nodeData.color);
- node.style.borderLeftColor = nodeData.color;
- node.style.borderLeftWidth = '4px';
- console.log('Created node element:', node);
-
- // Special styling for gradient input
- if (type === 'gradient_input') {
- node.classList.add('gradient-input');
- }
-
- node.innerHTML = `
- ${type === 'gradient_input' ? 'Gradient Input' : type.replace(/_/g, ' ')}
- ${nodeData.description}
- ${nodeData.inputs.length ? '' : ''}
- ${nodeData.outputs.length ? '' : ''}
- `;
-
- console.log('Canvas element:', this.canvas);
- console.log('Canvas dimensions:', this.canvas.offsetWidth, 'x', this.canvas.offsetHeight);
- this.canvas.appendChild(node);
- console.log('Node appended to canvas. Total nodes in canvas:', this.canvas.querySelectorAll('.node').length);
-
- // Set up node interactions
- node.addEventListener('mousedown', (e) => this.onNodeMouseDown(e, nodeId));
-
- // Set up port interactions
- const inputPort = node.querySelector('.node-port-input');
- const outputPort = node.querySelector('.node-port-output');
-
- if (inputPort) {
- inputPort.addEventListener('mousedown', (e) => this.onPortMouseDown(e, nodeId, 'input'));
- }
- if (outputPort) {
- outputPort.addEventListener('mousedown', (e) => this.onPortMouseDown(e, nodeId, 'output'));
- }
-
- // Store node data
- this.nodes.push({
- id: nodeId,
- type: type,
- element: node,
- x: x,
- y: y,
- params: nodeData.params || {}
- });
-
- return nodeId;
- }
-
- onNodeMouseDown(e, nodeId) {
- if (e.target.classList.contains('node-port')) return;
-
- e.preventDefault();
- this.selectNode(nodeId);
-
- const node = this.nodes.find(n => n.id === nodeId);
- this.draggingNode = node;
-
- const rect = node.element.getBoundingClientRect();
- const canvasRect = this.canvas.getBoundingClientRect();
- this.offset = {
- x: e.clientX - rect.left + canvasRect.left,
- y: e.clientY - rect.top + canvasRect.top
- };
- }
-
- onPortMouseDown(e, nodeId, portType) {
- e.preventDefault();
- e.stopPropagation();
-
- const port = e.target;
- const rect = port.getBoundingClientRect();
- const canvasRect = this.canvas.getBoundingClientRect();
-
- this.connecting = {
- nodeId: nodeId,
- portType: portType,
- startX: rect.left + rect.width / 2 - canvasRect.left,
- startY: rect.top + rect.height / 2 - canvasRect.top
- };
-
- // Create preview connection
- const line = document.createElementNS('http://www.w3.org/2000/svg', 'path');
- line.classList.add('connection-preview');
- line.id = 'preview-connection';
- this.svg.appendChild(line);
- }
-
- onCanvasMouseMove(e) {
- const rect = this.canvas.getBoundingClientRect();
- const x = e.clientX - rect.left;
- const y = e.clientY - rect.top;
-
- if (this.draggingNode) {
- this.draggingNode.element.style.left = (x - this.offset.x) + 'px';
- this.draggingNode.element.style.top = (y - this.offset.y) + 'px';
- this.draggingNode.x = x - this.offset.x;
- this.draggingNode.y = y - this.offset.y;
- this.updateConnections();
- }
-
- if (this.connecting) {
- const preview = document.getElementById('preview-connection');
- if (preview) {
- const path = this.createConnectionPath(
- this.connecting.startX,
- this.connecting.startY,
- x,
- y
- );
- preview.setAttribute('d', path);
- }
- }
- }
-
- onCanvasMouseUp(e) {
- if (this.connecting) {
- const preview = document.getElementById('preview-connection');
- if (preview) preview.remove();
-
- // Check if we're over a port
- const target = document.elementFromPoint(e.clientX, e.clientY);
- if (target && target.classList.contains('node-port')) {
- const targetNode = target.closest('.node');
- const targetNodeId = targetNode.id;
- const targetPortType = target.classList.contains('node-port-input') ? 'input' : 'output';
-
- // Validate connection
- if (this.canConnect(this.connecting.nodeId, this.connecting.portType, targetNodeId, targetPortType)) {
- this.addConnection(this.connecting.nodeId, this.connecting.portType, targetNodeId, targetPortType);
- this.updateCodePreview();
- }
- }
-
- this.connecting = null;
- }
-
- this.draggingNode = null;
- }
-
- canConnect(fromNodeId, fromPortType, toNodeId, toPortType) {
- // Can't connect to same node
- if (fromNodeId === toNodeId) return false;
-
- // Must connect output to input
- if (fromPortType === toPortType) return false;
-
- // Check if connection already exists
- const exists = this.connections.some(c =>
- (c.from.nodeId === fromNodeId && c.to.nodeId === toNodeId) ||
- (c.from.nodeId === toNodeId && c.to.nodeId === fromNodeId)
- );
-
- return !exists;
- }
-
- addConnection(fromNodeId, fromPortType, toNodeId, toPortType) {
- // Ensure output connects to input
- if (fromPortType === 'input') {
- [fromNodeId, toNodeId] = [toNodeId, fromNodeId];
- [fromPortType, toPortType] = [toPortType, fromPortType];
- }
-
- const connection = {
- from: { nodeId: fromNodeId, portType: fromPortType },
- to: { nodeId: toNodeId, portType: toPortType },
- element: null
- };
-
- const line = document.createElementNS('http://www.w3.org/2000/svg', 'path');
- line.classList.add('connection');
- line.classList.add('animated');
- this.svg.appendChild(line);
-
- connection.element = line;
- this.connections.push(connection);
-
- this.updateConnections();
- }
-
- updateConnections() {
- this.connections.forEach(conn => {
- const fromNode = this.nodes.find(n => n.id === conn.from.nodeId);
- const toNode = this.nodes.find(n => n.id === conn.to.nodeId);
-
- if (fromNode && toNode) {
- const fromPort = fromNode.element.querySelector('.node-port-output');
- const toPort = toNode.element.querySelector('.node-port-input');
-
- if (fromPort && toPort) {
- const fromRect = fromPort.getBoundingClientRect();
- const toRect = toPort.getBoundingClientRect();
- const canvasRect = this.canvas.getBoundingClientRect();
-
- const x1 = fromRect.left + fromRect.width / 2 - canvasRect.left;
- const y1 = fromRect.top + fromRect.height / 2 - canvasRect.top;
- const x2 = toRect.left + toRect.width / 2 - canvasRect.left;
- const y2 = toRect.top + toRect.height / 2 - canvasRect.top;
-
- const path = this.createConnectionPath(x1, y1, x2, y2);
- conn.element.setAttribute('d', path);
-
- // Store path coordinates for particle animation
- conn.pathCoords = { x1, y1, x2, y2 };
- }
- }
- });
- }
-
- // Animate gradient flow particles
- startGradientFlow() {
- if (this.flowInterval) return;
-
- this.flowInterval = setInterval(() => {
- this.connections.forEach(conn => {
- if (conn.pathCoords) {
- this.createFlowParticle(conn.pathCoords);
- }
- });
- }, 500);
- }
-
- stopGradientFlow() {
- if (this.flowInterval) {
- clearInterval(this.flowInterval);
- this.flowInterval = null;
- }
-
- // Remove all particles
- const particles = this.canvas.querySelectorAll('.flow-particle');
- particles.forEach(p => p.remove());
- }
-
- createFlowParticle(coords) {
- const particle = document.createElement('div');
- particle.className = 'flow-particle';
- particle.style.left = coords.x1 + 'px';
- particle.style.top = coords.y1 + 'px';
- this.canvas.appendChild(particle);
-
- // Animate along bezier curve
- const duration = 2000;
- const startTime = Date.now();
-
- const animate = () => {
- const elapsed = Date.now() - startTime;
- const t = Math.min(elapsed / duration, 1);
-
- if (t >= 1) {
- particle.remove();
- return;
- }
-
- // Calculate position on bezier curve
- const dx = Math.abs(coords.x2 - coords.x1);
- const cp1x = coords.x1 + dx * 0.5;
- const cp2x = coords.x2 - dx * 0.5;
-
- // Bezier curve formula
- const x = Math.pow(1-t, 3) * coords.x1 +
- 3 * Math.pow(1-t, 2) * t * cp1x +
- 3 * (1-t) * Math.pow(t, 2) * cp2x +
- Math.pow(t, 3) * coords.x2;
-
- const y = Math.pow(1-t, 3) * coords.y1 +
- 3 * Math.pow(1-t, 2) * t * coords.y1 +
- 3 * (1-t) * Math.pow(t, 2) * coords.y2 +
- Math.pow(t, 3) * coords.y2;
-
- particle.style.left = x + 'px';
- particle.style.top = y + 'px';
- particle.style.opacity = Math.sin(t * Math.PI);
-
- requestAnimationFrame(animate);
- };
-
- requestAnimationFrame(animate);
- }
-
- createConnectionPath(x1, y1, x2, y2) {
- const dx = Math.abs(x2 - x1);
- const cp1x = x1 + dx * 0.5;
- const cp2x = x2 - dx * 0.5;
- return `M ${x1} ${y1} C ${cp1x} ${y1}, ${cp2x} ${y2}, ${x2} ${y2}`;
- }
-
- selectNode(nodeId) {
- // Deselect previous
- if (this.selectedNode) {
- this.selectedNode.element.classList.remove('selected');
- }
-
- // Select new
- const node = this.nodes.find(n => n.id === nodeId);
- if (node) {
- node.element.classList.add('selected');
- this.selectedNode = node;
- this.showInspector(node);
- }
- }
-
- showInspector(node) {
- this.inspector.classList.add('active');
-
- const nodeData = window.CHAINABLE_FUNCTIONS[node.type];
- if (!nodeData || !nodeData.params) return;
-
- // Build inspector content
- let html = `${node.type.replace(/_/g, ' ')}
`;
-
- Object.entries(nodeData.params).forEach(([key, defaultValue]) => {
- const value = node.params[key] || defaultValue;
- html += `
-
- `;
- });
-
- this.inspector.innerHTML = html;
-
- // Set up parameter change listeners
- this.inspector.querySelectorAll('.inspector-input').forEach(input => {
- input.addEventListener('change', (e) => {
- const param = e.target.dataset.param;
- let value = e.target.value;
-
- // Convert to appropriate type
- if (e.target.type === 'number') {
- value = parseFloat(value);
- }
-
- node.params[param] = value;
- this.updateCodePreview();
- });
- });
- }
-
- updateCodePreview() {
- const codePreview = document.getElementById('code-output');
- if (!codePreview) return;
-
- // Generate Python code from pipeline
- const pipelineData = this.exportPipeline();
- const nodes = pipelineData.nodes.filter(n => n.type !== 'gradient_input');
-
- if (nodes.length === 0) {
- codePreview.textContent = '# Add optimizer components to build your pipeline';
- return;
- }
-
- let code = 'optimizer = BaseOpt(\n';
- code += ' model.parameters(),\n';
- code += ' lr=0.001,\n';
-
- // Add relevant parameters
- const params = {};
- nodes.forEach(node => {
- Object.assign(params, node.params);
- });
-
- Object.entries(params).forEach(([key, value]) => {
- if (typeof value === 'string') {
- code += ` ${key}="${value}",\n`;
- } else {
- code += ` ${key}=${value},\n`;
- }
- });
-
- code += ' fns=[\n';
- nodes.forEach(node => {
- code += ` ${node.type},\n`;
- });
- code += ' ]\n';
- code += ')';
-
- codePreview.textContent = code;
-
- // Update hidden pipeline data
- const pipelineInput = document.getElementById('pipeline-data');
- if (pipelineInput) {
- pipelineInput.value = JSON.stringify(pipelineData);
- }
- }
-
- exportPipeline() {
- return {
- nodes: this.nodes.map(n => ({
- id: n.id,
- type: n.type,
- x: n.x,
- y: n.y,
- params: n.params
- })),
- connections: this.connections.map(c => ({
- from: c.from,
- to: c.to
- }))
- };
- }
-
- loadRecipe(recipe) {
- // Clear existing nodes except gradient input
- this.nodes = this.nodes.filter(n => n.type === 'gradient_input');
- this.connections = [];
- this.selectedNode = null;
-
- // Clear canvas except gradient input
- const nodesToRemove = this.canvas.querySelectorAll('.node:not([id="node-0"])');
- nodesToRemove.forEach(n => n.remove());
-
- // Clear connections
- this.svg.innerHTML = '';
-
- // Add nodes from recipe
- let lastNodeId = 'node-0'; // gradient input
- let x = 300;
- const y = 300;
-
- recipe.forEach((item, index) => {
- const nodeId = this.addNode(item.name, x, y);
- const node = this.nodes.find(n => n.id === nodeId);
-
- // Set parameters
- if (item.params) {
- node.params = item.params;
- }
-
- // Connect to previous node
- this.addConnection(lastNodeId, 'output', nodeId, 'input');
-
- lastNodeId = nodeId;
- x += 200;
- });
-
- this.updateCodePreview();
- }
-}
-
-// Initialize when DOM is ready
-function initializeNodeEditor() {
- try {
- console.log('Attempting to initialize NodeEditor...');
- const editorElement = document.getElementById('node-editor');
- if (!editorElement) {
- console.error('node-editor element not found! Retrying in 100ms...');
- setTimeout(initializeNodeEditor, 100);
- return;
- }
- console.log('Found node-editor element:', editorElement);
- console.log('Editor dimensions:', editorElement.offsetWidth, 'x', editorElement.offsetHeight);
-
- if (window.nodeEditor) {
- console.log('NodeEditor already initialized');
- return;
- }
-
- window.nodeEditor = new NodeEditor('node-editor');
- console.log('NodeEditor initialized successfully');
- } catch (error) {
- console.error('Error initializing NodeEditor:', error);
- }
-}
-
-// Try multiple initialization methods
-document.addEventListener('DOMContentLoaded', initializeNodeEditor);
-// Also try after a delay in case content loads after DOMContentLoaded
-setTimeout(initializeNodeEditor, 100);
-setTimeout(initializeNodeEditor, 500);
-setTimeout(initializeNodeEditor, 1000);
diff --git a/pyproject.toml b/pyproject.toml
index 36916e8..7c2d78d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project]
name = "heavyball"
-description = "Compile-first PyTorch optimizer library — AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features"
-version = "2.3.2"
+description = "Compile-first PyTorch optimizer library - AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features"
+version = "3.0.0"
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
classifiers = ["Intended Audience :: Developers",
"Intended Audience :: Science/Research",
diff --git a/scripts/migrate_optimizer_state.py b/scripts/migrate_optimizer_state.py
index fb9f6e2..162538c 100644
--- a/scripts/migrate_optimizer_state.py
+++ b/scripts/migrate_optimizer_state.py
@@ -1,18 +1,21 @@
#!/usr/bin/env python3
"""
-Utility to migrate HeavyBall 1.x optimizer state dicts to the 2.0.0 layout.
+Utility to migrate HeavyBall optimizer state dicts to the current (3.x) layout.
-The script rewrites per-parameter state keys to the new transform-indexed names,
-reshapes state storage so each parameter-view owns its own dictionary, and
-injects the HeavyBall-specific metadata block expected by 2.0.0 optimizers.
+Supports two migration paths:
+- 1.x → 3.x: rewrites per-parameter state keys, reshapes storage, fixes param groups and metadata
+- 2.x → 3.x: renames foreach→multi_tensor in param groups, strips stale metadata
+
+Version auto-detection:
+- 1.x: flat per-parameter state (not nested by view index)
+- 2.x: nested state without 'multi_tensor' in param groups
+- 3.x: nested state with 'multi_tensor' in param groups (already current)
"""
from __future__ import annotations
import functools
import importlib
-import pickle
-import random
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
@@ -20,6 +23,56 @@
import torch
import typer
+_CLASS_RENAMES = {
+ # Foreach* internal names → canonical 3.x names
+ "ForeachAdamW": "AdamW",
+ "ForeachNAdam": "NAdam",
+ "ForeachAdEMAMix": "AdEMAMix",
+ "ForeachAdamC": "AdamC",
+ "ForeachRMSprop": "RMSprop",
+ "ForeachSFAdamW": "SFAdamW",
+ "ForeachADOPT": "ADOPT",
+ "ForeachMuon": "Muon",
+ "ForeachLaProp": "LaProp",
+ "ForeachSOAP": "SOAP",
+ "ForeachSOAPNAdam": "SOAPNAdam",
+ "ForeachSOAPAdEMAMix": "SOAPAdEMAMix",
+ "ForeachSignLaProp": "SignLaProp",
+ "ForeachSOLP": "SOLP",
+ "ForeachPSGDKron": "PSGDKron",
+ "ForeachPSGDLRA": "PSGDLRA",
+ # Deleted subclasses → parent (all were __init__-only default overrides)
+ "PaLMForeachSFAdamW": "SFAdamW",
+ "PaLMForeachSOAP": "SOAP",
+ "PrecondScheduleForeachSOAP": "SOAP",
+ "PrecondSchedulePaLMForeachSOAP": "SOAP",
+ "ForeachPurePSGD": "PSGDKron",
+ "ForeachCachedPSGDKron": "PSGDKron",
+ "ForeachCachedDelayedPSGDKron": "PSGDKron",
+ "ForeachDelayedPSGD": "PSGDKron",
+ "ForeachCachedNewtonPSGD": "PSGDKron",
+ "NewtonHybrid2PSGDKron": "PSGDKron",
+ "ForeachDelayedPSGDLRA": "PSGDLRA",
+ "ForeachNewtonPSGDLRA": "PSGDLRA",
+ "NewtonHybrid2PSGDLRA": "PSGDLRA",
+ # 2.x public aliases that pointed to Foreach* classes (now removed)
+ "PaLMSOAP": "SOAP",
+ "PaLMSFAdamW": "SFAdamW",
+ "PalmForEachSoap": "SOAP",
+ "PrecondScheduleSOAP": "SOAP",
+ "PrecondSchedulePaLMSOAP": "SOAP",
+ "PurePSGD": "PSGDKron",
+ "DelayedPSGD": "PSGDKron",
+ "CachedPSGDKron": "PSGDKron",
+ "CachedDelayedPSGDKron": "PSGDKron",
+ "NewtonPSGDKron": "PSGDKron",
+ "DelayedPSGDLRA": "PSGDLRA",
+ "NewtonPSGDLRA": "PSGDLRA",
+}
+
+_REMOVED_GROUP_KEYS = frozenset({"stochastic_schedule"})
+_REMOVED_META_KEYS = frozenset({"stochastic_schedule", "precond_rng"})
+
@dataclass
class TransformMapping:
@@ -28,16 +81,37 @@ class TransformMapping:
transform_idx: int
+def _resolve_class_name(class_name: str) -> str:
+ return _CLASS_RENAMES.get(class_name, class_name)
+
+
def _load_optimizer_class(qualified_name: str):
if "." in qualified_name:
module_name, class_name = qualified_name.rsplit(".", 1)
else:
module_name, class_name = "heavyball", qualified_name
+ class_name = _resolve_class_name(class_name)
module = importlib.import_module(module_name)
try:
return getattr(module, class_name)
except AttributeError as exc:
- raise ValueError(f"Optimizer class '{qualified_name}' not found") from exc
+ raise ValueError(f"Optimizer class '{module_name}.{class_name}' not found") from exc
+
+
+def _detect_version(state_dict: Dict[str, Any]) -> int:
+ nested = False
+ for entry in state_dict.get("state", {}).values():
+ if isinstance(entry, dict):
+ nested = any(isinstance(v, dict) for v in entry.values())
+ if not nested:
+ return 1
+ break
+
+ groups = state_dict.get("param_groups", [])
+ has_multi_tensor = any(isinstance(g, dict) and "multi_tensor" in g for g in groups)
+ if nested:
+ return 3 if has_multi_tensor else 2
+ return 3 if has_multi_tensor else 1
def _guess_tensor_meta(state_entry: Dict[str, Any]) -> Tuple[Tuple[int, ...], torch.dtype]:
@@ -69,8 +143,9 @@ def _build_dummy_parameters(state: Dict[int, Dict[str, Any]], param_groups: Sequ
def _normalise_group_options(group: Dict[str, Any]) -> Dict[str, Any]:
options: Dict[str, Any] = {}
for key, value in group.items():
- if key == "params":
+ if key == "params" or key in _REMOVED_GROUP_KEYS:
continue
+ key = "multi_tensor" if key == "foreach" else key
if isinstance(value, list) and key in {"betas", "weight_decay_steps"}:
options[key] = tuple(value)
else:
@@ -104,9 +179,15 @@ def walk(queue: Iterable[Any]):
stack.append(current.fn)
elif isinstance(current, functools.partial): # type: ignore[name-defined]
stack.append(current.func)
- elif isinstance(current, C.Branch):
+ elif isinstance(current, C.Parallel):
for branch in current.branches:
stack.extend(branch)
+ elif isinstance(current, C.Route):
+ for _, fns in current.routes:
+ if fns:
+ stack.extend(fns)
+ if current.default:
+ stack.extend(current.default)
elif isinstance(current, (list, tuple)):
stack.extend(current)
@@ -180,7 +261,24 @@ def _migrate_single_state(entry: Dict[str, Any], mappings: List[TransformMapping
return migrated
-def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]:
+def _migrate_v2_to_v3(state_dict: Dict[str, Any]) -> None:
+ for group in state_dict.get("param_groups", []):
+ if not isinstance(group, dict):
+ continue
+ if "foreach" in group:
+ group["multi_tensor"] = group.pop("foreach")
+ for key in _REMOVED_GROUP_KEYS:
+ group.pop(key, None)
+
+ hb = state_dict.get("heavyball", {})
+ for key in _REMOVED_META_KEYS:
+ hb.pop(key, None)
+ inner = hb.get("inner_group", {})
+ for key in _REMOVED_META_KEYS:
+ inner.pop(key, None)
+
+
+def _migrate_v1_state(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]:
opt_cls = _load_optimizer_class(optimizer_class)
optimizer, _ = _instantiate_optimizer(opt_cls, old_state)
template = optimizer.state_dict()
@@ -199,11 +297,19 @@ def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[
heavyball_meta = migrated.setdefault("heavyball", template.get("heavyball", {}))
if "inner_group" not in heavyball_meta:
- heavyball_meta["inner_group"] = {"stochastic_schedule": None}
- if "stochastic_schedule" not in heavyball_meta:
- heavyball_meta["stochastic_schedule"] = None
- if "precond_rng" not in heavyball_meta:
- heavyball_meta["precond_rng"] = pickle.dumps(random.Random(0x12312))
+ heavyball_meta["inner_group"] = {}
+ return migrated
+
+
+def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]:
+ version = _detect_version(old_state)
+ if version == 1:
+ migrated = _migrate_v1_state(old_state, optimizer_class)
+ elif version == 2:
+ migrated = dict(old_state)
+ else:
+ return dict(old_state)
+ _migrate_v2_to_v3(migrated)
return migrated
@@ -221,7 +327,7 @@ def _resolve_state_container(root: Dict[str, Any], key_path: Sequence[str]) -> D
app = typer.Typer(help="Utilities for migrating HeavyBall optimizer checkpoints.")
-@app.command(help="Migrate a HeavyBall optimizer state dict to the 2.0.0 layout.")
+@app.command(help="Migrate a HeavyBall optimizer state dict to the current (3.x) layout.")
def migrate(
checkpoint: Path = typer.Argument(
...,
@@ -234,7 +340,7 @@ def migrate(
),
optimizer_class: str = typer.Argument(
...,
- help="Optimizer class to instantiate (e.g., heavyball.ForeachAdamW)",
+ help="Optimizer class name (e.g., heavyball.AdamW). Old names like ForeachAdamW are resolved automatically.",
),
state_key: str = typer.Option(
"optimizer",
diff --git a/test/test_ademamix.py b/test/test_ademamix.py
index 847a290..49e53de 100644
--- a/test/test_ademamix.py
+++ b/test/test_ademamix.py
@@ -101,7 +101,7 @@ def test_ademamix_matches_reference_math():
alpha_warmup = 4
param = torch.nn.Parameter(initial.clone())
- optimizer = heavyball.ForeachAdEMAMix(
+ optimizer = heavyball.AdEMAMix(
[param],
lr=lr,
betas=betas,
@@ -110,7 +110,7 @@ def test_ademamix_matches_reference_math():
alpha=alpha,
beta3_warmup=beta3_warmup,
alpha_warmup=alpha_warmup,
- foreach=False,
+ multi_tensor=False,
)
for grad in grads:
@@ -149,7 +149,7 @@ def test_soap_ademamix_projects_gradients_into_eigenbasis():
torch.manual_seed(7)
param = torch.nn.Parameter(torch.randn(2, 2))
- optimizer = heavyball.ForeachSOAPAdEMAMix([param], lr=0.01, foreach=False)
+ optimizer = heavyball.SOAPAdEMAMix([param], lr=0.01, multi_tensor=False)
# First call initializes the SOAP preconditioner state without applying an update.
param.grad = torch.randn_like(param)
diff --git a/test/test_chainable_cpu.py b/test/test_chainable_cpu.py
index 124cbe4..ed2da33 100644
--- a/test/test_chainable_cpu.py
+++ b/test/test_chainable_cpu.py
@@ -36,7 +36,7 @@ def negate(_, __, update, ___, ____):
def merge_fn(outputs):
return [sum(vals) / len(vals) for vals in zip(*outputs)]
- branch = C.Branch([[double], [negate]], merge_fn)
+ branch = C.Parallel([[double], [negate]], merge_fn)
update = [torch.ones(2)]
grad = [torch.ones(2)]
@@ -70,46 +70,37 @@ def state_fn(_x):
# Optimizers whose chains are purely elementwise must NOT need gather
_EXPECT_NO_GATHER = {
"SGD",
- "ForeachAdamW",
- "ForeachNAdam",
- "ForeachAdEMAMix",
+ "AdamW",
+ "NAdam",
+ "AdEMAMix",
"UnscaledAdamW",
- "ForeachAdamC",
- "ForeachRMSprop",
- "ForeachSFAdamW",
- "ForeachADOPT",
- "ForeachLaProp",
- "PaLMForeachSFAdamW",
+ "AdamC",
+ "RMSprop",
+ "SFAdamW",
+ "ADOPT",
+ "LaProp",
}
# Optimizers whose chains use shape-dependent or global-reduction ops must need gather
_EXPECT_GATHER = {
- "ForeachSOAP",
- "ForeachSOAPNAdam",
- "ForeachSOAPAdEMAMix",
- "ForeachSOLP",
- "ForeachMuon",
+ "SOAP",
+ "SOAPNAdam",
+ "SOAPAdEMAMix",
+ "SOLP",
+ "Muon",
"MuonLaProp",
"OrthoLaProp",
"LaPropOrtho",
- "ForeachPSGDKron",
- "ForeachPurePSGD",
- "ForeachCachedPSGDKron",
- "ForeachDelayedPSGD",
- "ForeachCachedDelayedPSGDKron",
- "ForeachCachedNewtonPSGD",
- "NewtonHybrid2PSGDKron",
- "ForeachPSGDLRA",
- "ForeachDelayedPSGDLRA",
- "ForeachNewtonPSGDLRA",
- "NewtonHybrid2PSGDLRA",
+ "PSGDKron",
+ "LATHER",
+ "PSGDLRA",
+ "PSGDPRO",
"SUDSAdamW",
"Scion",
- "ForeachSignLaProp",
+ "SignLaProp",
"MSAMLaProp",
- "PaLMForeachSOAP",
- "PrecondScheduleForeachSOAP",
- "PrecondSchedulePaLMForeachSOAP",
+ "HyperBallAdamW",
+ "MuonAdamW",
}
_SKIP_INSTANTIATE = {"SplitOpt", "SAMWrapper"}
@@ -120,7 +111,7 @@ def state_fn(_x):
@pytest.mark.parametrize("opt_name", _ALL_OPTS)
def test_needs_gather_flag(opt_name):
params = [torch.nn.Parameter(torch.randn(4, 4))]
- extra = {"max_lr": 0.0025} if opt_name == "ForeachAdamC" else {}
+ extra = {"max_lr": 0.0025} if opt_name == "AdamC" else {}
opt = getattr(heavyball, opt_name)(params, lr=1e-3, **extra)
if opt_name in _EXPECT_NO_GATHER:
assert not opt._needs_gather, f"{opt_name} should be elementwise (no gather needed)"
diff --git a/test/test_compile_step.py b/test/test_compile_step.py
new file mode 100644
index 0000000..ce0a8e5
--- /dev/null
+++ b/test/test_compile_step.py
@@ -0,0 +1,114 @@
+import inspect
+
+import pytest
+import torch
+
+import heavyball
+from heavyball.chainable import ChainOpt, WarmupGuard, _walk_fns
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+EXTRA_KWARGS = {
+ "AdamC": {"max_lr": 0.01},
+}
+
+
+def _optimizer_params():
+ seen = set()
+ params = []
+ for name in heavyball.__all__:
+ if not hasattr(heavyball, name):
+ continue
+ obj = getattr(heavyball, name)
+ if not inspect.isclass(obj):
+ continue
+ if not issubclass(obj, torch.optim.Optimizer):
+ continue
+ ident = id(obj)
+ if ident in seen:
+ continue
+ seen.add(ident)
+ if name == "SplitOpt":
+ params.append(
+ pytest.param(name, obj, id=name, marks=pytest.mark.skip(reason="SplitOpt requires dict specs"))
+ )
+ continue
+ params.append(pytest.param(name, obj, id=name))
+ return params
+
+
+def _make_model():
+ return torch.nn.Sequential(
+ torch.nn.Linear(8, 16),
+ torch.nn.Tanh(),
+ torch.nn.Linear(16, 4),
+ ).to(DEVICE)
+
+
+def _run_steps(model, optimizer, n=5, seed=0xDEADBEEF):
+ torch.manual_seed(seed)
+ for _ in range(n):
+
+ def closure():
+ optimizer.zero_grad(set_to_none=True)
+ data = torch.randn(4, 8, device=DEVICE)
+ target = torch.randn(4, 4, device=DEVICE)
+ loss = torch.nn.functional.mse_loss(model(data), target)
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+
+@pytest.mark.parametrize("opt_name,opt_cls", _optimizer_params())
+def test_compile_step_matches_eager(opt_name, opt_cls):
+ """compile_step=True must produce identical parameters to compile_step=False."""
+ sig = inspect.signature(opt_cls.__init__)
+ if "compile_step" not in sig.parameters:
+ pytest.skip("optimizer does not accept compile_step")
+
+ kwargs = dict(EXTRA_KWARGS.get(opt_name, {}))
+
+ torch.manual_seed(0xDEADBEEF)
+ model_ref = _make_model()
+ model_test = _make_model()
+ model_test.load_state_dict(model_ref.state_dict())
+
+ opt_ref = opt_cls(model_ref.parameters(), compile_step=False, **kwargs)
+ opt_test = opt_cls(model_test.parameters(), compile_step=True, **kwargs)
+
+ _run_steps(model_ref, opt_ref)
+ _run_steps(model_test, opt_test)
+
+ for p_ref, p_test in zip(model_ref.parameters(), model_test.parameters()):
+ diff = (p_ref.data - p_test.data).abs().max().item()
+ assert diff < 1e-4, f"compile_step diverged: max_diff={diff}"
+
+
+def _max_warmup(opt):
+ return max((len(ft.warmup_fns) for ft in _walk_fns(opt.fns) if isinstance(ft, WarmupGuard)), default=0)
+
+
+@pytest.mark.parametrize("opt_name,opt_cls", _optimizer_params())
+def test_needs_init_clears(opt_name, opt_cls):
+ """_needs_init must become False after max_warmup + 1 steps for all ChainOpt optimizers.
+
+ Catches bugs where Route-based or warmup_guard-based optimizers permanently
+ force eager mode because different params accumulate different is_initialized
+ sets that never individually cover _transform_ids.
+ """
+ if not issubclass(opt_cls, ChainOpt):
+ pytest.skip("not a ChainOpt")
+
+ kwargs = dict(EXTRA_KWARGS.get(opt_name, {}))
+ model = _make_model()
+ opt = opt_cls(model.parameters(), **kwargs)
+ n = _max_warmup(opt) + 1
+
+ _run_steps(model, opt, n=n)
+
+ for group in opt.param_groups:
+ state = [opt.state_(p) for p in group["params"]]
+ assert not opt._needs_init(state), (
+ f"{opt_name}: _needs_init stuck True after {n} steps | compile_step will never engage"
+ )
diff --git a/test/test_cpu_features.py b/test/test_cpu_features.py
index e241f98..0a8b746 100644
--- a/test/test_cpu_features.py
+++ b/test/test_cpu_features.py
@@ -44,9 +44,9 @@ def _make_batch(
@pytest.mark.parametrize(
"opt_name",
[
- "ForeachSOAP",
+ "SOAP",
"Muon",
- "ForeachAdamW",
+ "AdamW",
],
)
def test_selected_optimizers_run_on_cpu(opt_name: str) -> None:
@@ -92,8 +92,8 @@ def test_mars_flag_changes_behavior() -> None:
model_a, data, target = _make_batch()
model_b = deepcopy(model_a)
- opt_a = heavyball.ForeachAdamW(model_a.parameters(), mars=False, warmup_steps=0)
- opt_b = heavyball.ForeachAdamW(model_b.parameters(), mars=True, warmup_steps=0)
+ opt_a = heavyball.AdamW(model_a.parameters(), mars=False, warmup_steps=0)
+ opt_b = heavyball.AdamW(model_b.parameters(), mars=True, warmup_steps=0)
init = [param.detach().clone() for param in model_a.parameters()]
@@ -112,7 +112,7 @@ def test_mars_flag_changes_behavior() -> None:
def test_sam_wrapper_requires_closure() -> None:
model = nn.Linear(4, 2)
- base = heavyball.ForeachAdamW(model.parameters())
+ base = heavyball.AdamW(model.parameters())
wrapper = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=base)
with pytest.raises(ValueError):
@@ -132,3 +132,83 @@ def closure():
after = [param.detach() for param in model.parameters()]
diff = torch.cat([(a - b).reshape(-1) for a, b in zip(after, before, strict=True)])
assert diff.norm().item() > 0.0
+
+
+def test_multiple_param_groups_keep_updating() -> None:
+ p1 = nn.Parameter(torch.zeros(()))
+ p2 = nn.Parameter(torch.zeros(()))
+ opt = heavyball.SGD(
+ [
+ {"params": [p1]},
+ {"params": [p2]},
+ ],
+ lr=0.1,
+ beta=0.0,
+ warmup_steps=0,
+ )
+
+ for _ in range(3):
+ p1.grad = torch.ones_like(p1)
+ p2.grad = torch.full_like(p2, 2.0)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ assert torch.allclose(p1.detach(), torch.tensor(-0.3))
+ assert torch.allclose(p2.detach(), torch.tensor(-0.6))
+
+
+def test_group_step_does_not_reset_when_active_param_changes() -> None:
+ p1 = nn.Parameter(torch.zeros(()))
+ p2 = nn.Parameter(torch.zeros(()))
+ opt = heavyball.SGD([p1, p2], lr=0.1, beta=0.0, warmup_steps=3)
+
+ p1.grad = torch.ones_like(p1)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ p2.grad = torch.ones_like(p2)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ assert p1.item() == pytest.approx(-0.025)
+ assert p2.item() == pytest.approx(-0.05)
+
+
+def test_string_clipping_shorthands_match_public_api() -> None:
+ model, data, target = _make_batch()
+ opt = heavyball.SGD(
+ model.parameters(),
+ lr=1e-3,
+ beta=0.0,
+ gradient_clipping="l2_clip_",
+ update_clipping="trust_region_clip_",
+ )
+
+ loss = _train_once(opt, model, data, target, steps=2)
+ assert torch.isfinite(torch.tensor(loss))
+
+
+@pytest.mark.parametrize("opt_cls", [heavyball.SFAdamW, heavyball.MSAMLaProp])
+def test_mode_switches_are_idempotent(opt_cls) -> None:
+ p = nn.Parameter(torch.tensor([1.0, -1.0]))
+ opt = opt_cls([p], lr=1e-2)
+
+ p.grad = torch.ones_like(p)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ opt.eval()
+ eval_once = p.detach().clone()
+ opt.eval()
+ eval_twice = p.detach().clone()
+
+ assert torch.allclose(eval_once, eval_twice)
+ assert opt.param_groups[0]["train_mode"] is False
+
+ opt.train()
+ train_once = p.detach().clone()
+ opt.train()
+ train_twice = p.detach().clone()
+
+ assert torch.allclose(train_once, train_twice)
+ assert opt.param_groups[0]["train_mode"] is True
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 1b60e7b..d96d9d1 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -17,14 +17,13 @@
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
]
-_EXTRA_KWARGS = {"ForeachAdamC": {"max_lr": 0.0025}}
+_EXTRA_KWARGS = {"AdamC": {"max_lr": 0.0025}}
_MODEL_SEED = 42
_DATA_SEED = 0xABCD
# LRA builds one preconditioner over all grads, under FSDP each rank only has a subset
_FSDP_SKIP = {
- "ForeachPSGDLRA": "LRA preconditioner scope differs under FSDP",
- "ForeachDelayedPSGDLRA": "LRA preconditioner scope differs under FSDP",
+ "PSGDLRA": "LRA preconditioner scope differs under FSDP",
}
# torch.compile(dynamic=False) specializes on list length → different kernels per rank
@@ -39,20 +38,20 @@
_INTEGRATION_OPTS = [
n
for n in [
- "ForeachAdamW",
- "ForeachSOAP",
- "ForeachMuon",
- "ForeachPSGDKron",
- "ForeachPurePSGD",
+ "AdamW",
+ "SOAP",
+ "Muon",
+ "PSGDKron",
"Scion",
- "ForeachLaProp",
+ "LaProp",
"MuonLaProp",
- "ForeachSOAPNAdam",
- "ForeachCachedPSGDKron",
+ "SOAPNAdam",
]
if n in REPRESENTATIVE_OPTS and n not in _FSDP_SKIP
]
+_OWNER_EDGE_OPTS = [n for n in ["Muon"] if n in REPRESENTATIVE_OPTS and n not in _FSDP_SKIP]
+
def _set_cache(cache_dir, compile_mode="default"):
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
@@ -104,6 +103,14 @@ def _make_integration_model():
)
+def _make_owner_edge_model():
+ return nn.Sequential(
+ nn.Linear(32, 32, bias=False),
+ nn.ReLU(),
+ nn.Linear(32, 1, bias=False),
+ )
+
+
def _make_data(dim=32, n=4):
torch.manual_seed(_DATA_SEED)
return [torch.randn(4, dim, device="cuda") for _ in range(n)]
@@ -121,6 +128,10 @@ def _make_integration_data():
return _make_data(64, 8)
+def _make_owner_edge_data():
+ return _make_data(32, 6)
+
+
def _train(model, opt, data):
for x in data:
model(x).mean().backward()
@@ -266,7 +277,7 @@ def _assert_close(ref, result, label, rtol=0, atol=0):
assert False, f"{label}: param {i} diverged beyond 1 ULP ({n} elements, max |diff|={worst:.2e})"
-def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label):
+def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label, world_size=2, tol=None):
cache_dir = tempfile.mkdtemp(prefix=f"hb_{label}_{opt_name}_")
ref_path = str(tmp_path / "ref.pt")
mp.spawn(_ref_worker, args=(opt_name, ref_path, cache_dir, None, model_fn, data_fn), nprocs=1, join=True)
@@ -274,12 +285,14 @@ def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label):
result_path = str(tmp_path / "result.pt")
mp.spawn(
_fsdp_worker,
- args=(2, str(tmp_path / "store"), opt_name, result_path, cache_dir, None, model_fn, data_fn),
- nprocs=2,
+ args=(world_size, str(tmp_path / "store"), opt_name, result_path, cache_dir, None, model_fn, data_fn),
+ nprocs=world_size,
join=True,
)
- tol = dict(rtol=1e-2, atol=1e-4) if opt_name in _FSDP_PSGD else {}
- _assert_close(ref, torch.load(result_path, weights_only=True), f"{label}/{opt_name}", **tol)
+ base_tol = dict(rtol=1e-2, atol=1e-4) if opt_name in _FSDP_PSGD else {}
+ if tol is not None:
+ base_tol.update({k: max(base_tol.get(k, 0), v) for k, v in tol.items()})
+ _assert_close(ref, torch.load(result_path, weights_only=True), f"{label}/{opt_name}", **base_tol)
@pytest.mark.parametrize("opt_name", REPRESENTATIVE_OPTS)
@@ -339,3 +352,16 @@ def test_fsdp_misaligned(opt_name, tmp_path):
@pytest.mark.parametrize("opt_name", _INTEGRATION_OPTS)
def test_fsdp_integration(opt_name, tmp_path):
_run_fsdp_test(opt_name, tmp_path, _make_integration_model, _make_integration_data, "FSDP-integ")
+
+
+@pytest.mark.parametrize("opt_name", _OWNER_EDGE_OPTS)
+def test_fsdp_owner_edge(opt_name, tmp_path):
+ _run_fsdp_test(
+ opt_name,
+ tmp_path,
+ _make_owner_edge_model,
+ _make_owner_edge_data,
+ "FSDP-owner3",
+ world_size=3,
+ tol=dict(atol=5e-8),
+ )
diff --git a/test/test_ecc.py b/test/test_ecc.py
index 97c0eab..45c4758 100644
--- a/test/test_ecc.py
+++ b/test/test_ecc.py
@@ -14,13 +14,13 @@
ULP_MODES = ["bf16+8", "bf16+16", "fp16+8", "fp16+16"]
_OPTIMIZERS = [
- (heavyball.ForeachAdamW, 5e-2, {}),
- (heavyball.ForeachADOPT, 5e-2, {}),
- (heavyball.ForeachNAdam, 1e-2, {}),
- (heavyball.ForeachLaProp, 5e-2, {}),
- (heavyball.ForeachAdEMAMix, 5e-2, {"betas": (0.9, 0.999, 0.9999)}),
- (heavyball.ForeachRMSprop, 1e-2, {}),
- (heavyball.PaLMForeachSFAdamW, 1e-2, {}),
+ (heavyball.AdamW, 5e-2, {}),
+ (heavyball.ADOPT, 5e-2, {}),
+ (heavyball.NAdam, 1e-2, {}),
+ (heavyball.LaProp, 5e-2, {}),
+ (heavyball.AdEMAMix, 5e-2, {"betas": (0.9, 0.999, 0.9999)}),
+ (heavyball.RMSprop, 1e-2, {}),
+ (heavyball.SFAdamW, 1e-2, {}),
]
@@ -125,12 +125,12 @@ def test_ecc_convergence(opt_cls, lr, extra_kw, mode):
def test_param_ecc_convergence(combined):
set_torch()
data, target = _problem()
- m0, o0 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2)
+ m0, o0 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2)
losses_base = _train(m0, o0, data, target, 200)
kw = {"param_ecc": "bf16+8"}
if combined:
kw["ecc"] = "bf16+8"
- m1, o1 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, **kw)
+ m1, o1 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, **kw)
losses_ecc = _train(m1, o1, data, target, 200)
p = list(m1.parameters())[0]
assert p.dtype == torch.bfloat16
@@ -148,7 +148,7 @@ def test_state_layout_and_invariants(mode):
cfg = ECCConfig(mode)
torch.manual_seed(42)
model = nn.Linear(32, 16, bias=False, device="cuda")
- opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, ecc=mode)
+ opt = heavyball.AdamW(model.parameters(), lr=1e-2, ecc=mode)
x = torch.randn(4, 32, device="cuda")
for _ in range(3):
model(x).sum().backward()
@@ -174,7 +174,7 @@ def test_ademamix_three_vars():
set_torch()
torch.manual_seed(42)
model = nn.Linear(64, 32, bias=False, device="cuda")
- opt = heavyball.ForeachAdEMAMix(model.parameters(), lr=5e-2, betas=(0.9, 0.999, 0.9999), ecc="bf16+8")
+ opt = heavyball.AdEMAMix(model.parameters(), lr=5e-2, betas=(0.9, 0.999, 0.9999), ecc="bf16+8")
x = torch.randn(4, 64, device="cuda")
for _ in range(10):
model(x).sum().backward()
@@ -193,7 +193,7 @@ def test_ademamix_three_vars():
def test_combined_ecc_dtypes():
set_torch()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 32, 16, 1e-2, ecc="bf16+16", param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 32, 16, 1e-2, ecc="bf16+16", param_ecc="bf16+8")
data, target = _problem(in_dim=32, out_dim=16, n=8)
_train(m, o, data, target, 100)
p = list(m.parameters())[0]
@@ -211,7 +211,7 @@ def test_shapes_and_bias():
set_torch()
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(32, 16, bias=True), nn.Linear(16, 4, bias=False)).cuda()
- opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, ecc="bf16+8")
+ opt = heavyball.AdamW(model.parameters(), lr=1e-2, ecc="bf16+8")
data, target = torch.randn(16, 32, device="cuda"), torch.randn(16, 4, device="cuda")
losses = _train(model, opt, data, target, 50)
for p in model.parameters():
@@ -225,9 +225,9 @@ def test_shapes_and_bias():
def test_foreach_false():
set_torch()
data, target = _problem(in_dim=32, out_dim=4, n=16)
- m_fe, o_fe = _model_opt(heavyball.ForeachAdamW, 32, 4, 1e-2, ecc="bf16+8", foreach=True)
+ m_fe, o_fe = _model_opt(heavyball.AdamW, 32, 4, 1e-2, ecc="bf16+8", multi_tensor=True)
losses_fe = _train(m_fe, o_fe, data, target, 50)
- m_nf, o_nf = _model_opt(heavyball.ForeachAdamW, 32, 4, 1e-2, ecc="bf16+8", foreach=False)
+ m_nf, o_nf = _model_opt(heavyball.AdamW, 32, 4, 1e-2, ecc="bf16+8", multi_tensor=False)
losses_nf = _train(m_nf, o_nf, data, target, 50)
assert losses_nf[-1] < losses_nf[0] * 0.5
assert 0.3 < losses_nf[-1] / max(losses_fe[-1], 1e-12) < 3.0
@@ -239,7 +239,7 @@ def test_param_groups():
set_torch()
torch.manual_seed(42)
m1, m2 = nn.Linear(16, 8, bias=False, device="cuda"), nn.Linear(8, 4, bias=False, device="cuda")
- opt = heavyball.ForeachAdamW(
+ opt = heavyball.AdamW(
[
{"params": m1.parameters(), "ecc": "bf16+8"},
{"params": m2.parameters()},
@@ -261,7 +261,7 @@ def test_zero_gradients():
set_torch()
torch.manual_seed(42)
p = nn.Parameter(torch.randn(16, 8, device="cuda"))
- opt = heavyball.ForeachAdamW([p], lr=1e-2, ecc="bf16+8")
+ opt = heavyball.AdamW([p], lr=1e-2, ecc="bf16+8")
for _ in range(10):
p.grad = torch.zeros_like(p)
opt.step()
@@ -276,10 +276,10 @@ def test_state_save_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt, sd_model = deepcopy(o.state_dict()), deepcopy(m.state_dict())
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
losses_after = _train(m2, o2, data, target, 10)
@@ -326,9 +326,9 @@ def _measure_peak(cls, n, lr, ecc=None, param_ecc=None, steps=3):
@pytest.mark.parametrize(
"cls,lr",
[
- (heavyball.ForeachAdamW, 1e-3),
- (heavyball.PaLMForeachSFAdamW, 1e-2),
- (heavyball.ForeachRMSprop, 1e-2),
+ (heavyball.AdamW, 1e-3),
+ (heavyball.SFAdamW, 1e-2),
+ (heavyball.RMSprop, 1e-2),
],
ids=["AdamW", "SFAdamW", "RMSprop"],
)
@@ -346,7 +346,7 @@ def test_ecc_peak_memory(cls, lr, mode):
@pytest.mark.parametrize("combined", [False, True], ids=["param_only", "state+param"])
def test_param_ecc_peak_memory(combined):
n = 2**24
- cls, lr = heavyball.PaLMForeachSFAdamW, 1e-2
+ cls, lr = heavyball.SFAdamW, 1e-2
pre_base, peak_base = _measure_peak(cls, n, lr)
ecc = "bf16+8" if combined else None
pre_ecc, peak_ecc = _measure_peak(cls, n, lr, ecc=ecc, param_ecc="bf16+8")
@@ -357,19 +357,19 @@ def test_param_ecc_peak_memory(combined):
def test_ecc_live_path_nonzero_correction():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.ForeachAdamW, 16, 8, 5e-2, ecc="bf16+8")
+ m, o = _model_opt(heavyball.AdamW, 16, 8, 5e-2, ecc="bf16+8")
_train(m, o, data, target, 20)
p = list(m.parameters())[0]
st, ecc_keys = _ecc_keys(o, p)
for ek in ecc_keys:
- assert st[ek].any(), f"ECC correction '{ek}' is all zeros — stochastic_round_ likely mutating source"
+ assert st[ek].any(), f"ECC correction '{ek}' is all zeros - stochastic_round_ likely mutating source"
del m, o
clean()
def test_lerp_returns_valid_fp32():
set_torch()
- m, o = _model_opt(heavyball.ForeachAdamW, 16, 8, 5e-2, ecc="bf16+8")
+ m, o = _model_opt(heavyball.AdamW, 16, 8, 5e-2, ecc="bf16+8")
data, target = _problem()
_train(m, o, data, target, 10)
p = list(m.parameters())[0]
@@ -398,7 +398,7 @@ def test_param_ecc_merge_dims():
nn.Flatten(),
nn.Linear(64, 8, bias=False),
).cuda()
- opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
+ opt = heavyball.SOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
conv_p = list(model.parameters())[0]
assert conv_p.dtype == torch.bfloat16
st = _flat_state(opt, conv_p)
@@ -417,7 +417,7 @@ def test_param_ecc_merge_dims():
def test_param_ecc_dtype_at_construction():
set_torch()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
p = list(m.parameters())[0]
assert p.dtype == torch.bfloat16, "param should be bf16 immediately after construction"
st = _flat_state(o, p)
@@ -431,11 +431,11 @@ def test_param_ecc_save_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt, sd_model = deepcopy(o.state_dict()), deepcopy(m.state_dict())
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
p2 = list(m2.parameters())[0]
@@ -452,7 +452,7 @@ def test_param_ecc_partial_state_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
@@ -460,7 +460,7 @@ def test_param_ecc_partial_state_restore():
for idx_state in param_state.values():
if isinstance(idx_state, dict):
idx_state.pop("param::ecc", None)
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
st2 = _flat_state(o2, list(m2.parameters())[0])
@@ -478,12 +478,12 @@ def test_param_ecc_empty_state_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
sd_opt["state"] = {}
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
p2 = list(m2.parameters())[0]
@@ -508,7 +508,7 @@ def test_param_ecc_merged_view_partial_restore():
nn.Flatten(),
nn.Linear(64, 8, bias=False),
).cuda()
- opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
+ opt = heavyball.SOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
data = torch.randn(4, 3, 8, 8, device="cuda")
target = torch.randn(4, 8, device="cuda")
conv_p = list(model.parameters())[0]
@@ -529,7 +529,7 @@ def test_param_ecc_merged_view_partial_restore():
nn.Flatten(),
nn.Linear(64, 8, bias=False),
).cuda()
- opt2 = heavyball.ForeachSOAP(model2.parameters(), lr=1e-3, param_ecc="bf16+8")
+ opt2 = heavyball.SOAP(model2.parameters(), lr=1e-3, param_ecc="bf16+8")
model2.load_state_dict(sd_model)
opt2.load_state_dict(sd_opt)
conv_p2 = list(model2.parameters())[0]
@@ -549,7 +549,7 @@ def test_param_ecc_merged_view_partial_restore():
def test_optimizer_kwargs_not_in_param_groups():
set_torch()
p = torch.nn.Parameter(torch.randn(4, 4, device="cuda"))
- o = heavyball.ForeachAdamW([p], lr=1e-3, compile_step=True, promote=True)
+ o = heavyball.AdamW([p], lr=1e-3, compile_step=True, promote=True)
assert o.compile_step is True
assert o.promote is True
assert "compile_step" not in o.param_groups[0]
@@ -564,7 +564,7 @@ def test_param_ecc_load_order_model_before_optimizer():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
@@ -573,7 +573,7 @@ def test_param_ecc_load_order_model_before_optimizer():
if isinstance(idx_state, dict):
idx_state.pop("param::ecc", None)
# model-first: load model, then optimizer
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
p2 = list(m2.parameters())[0]
@@ -594,7 +594,7 @@ def test_param_ecc_load_order_optimizer_before_model():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
@@ -603,7 +603,7 @@ def test_param_ecc_load_order_optimizer_before_model():
if isinstance(idx_state, dict):
idx_state.pop("param::ecc", None)
# optimizer-first: load optimizer, then model
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
o2.load_state_dict(sd_opt)
m2.load_state_dict(sd_model)
p2 = list(m2.parameters())[0]
diff --git a/test/test_foreach.py b/test/test_foreach.py
index de3fa2a..06dd9f4 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -57,12 +57,12 @@ def test_foreach(
losses = [[], []]
for i in range(total_runs):
- for foreach in [True, False]:
- lss, pk = losses[int(foreach)], peaks[int(foreach)]
+ for multi_tensor in [True, False]:
+ lss, pk = losses[int(multi_tensor)], peaks[int(multi_tensor)]
torch.manual_seed(0x2131290)
model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
- o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach)
+ o = get_optim(opt, model.parameters(), lr=1e-3, multi_tensor=multi_tensor)
clean()
@@ -92,17 +92,17 @@ def test_foreach(
cutoff = warmup_runs * iterations
losses = [loss_list[cutoff:] for loss_list in losses]
- for peak_no_foreach, peak_foreach in zip(*peaks):
- assert peak_no_foreach < peak_foreach
+ for peak_single, peak_multi in zip(*peaks):
+ assert peak_single < peak_multi
- # no-foreach LRA is a different optimizer (per-parameter LRA vs global LRA),
+ # single-tensor LRA is a different optimizer (per-parameter LRA vs global LRA),
# so we only check that both converge, not that they match.
if "LRA" in opt.__name__:
return
- for loss_no_foreach, loss_foreach in zip(*losses):
- if torch.isnan(loss_no_foreach) and torch.isnan(loss_foreach):
+ for loss_single, loss_multi in zip(*losses):
+ if torch.isnan(loss_single) and torch.isnan(loss_multi):
continue
# increase error tolerance for PSGD, as we have different RNGs -> expected differences
- assert torch.allclose(loss_no_foreach, loss_foreach, rtol=0.01 if "PSGD" in opt.__name__ else 1e-5)
+ assert torch.allclose(loss_single, loss_multi, rtol=0.01 if "PSGD" in opt.__name__ else 1e-5)
diff --git a/test/test_helpers_cpu.py b/test/test_helpers_cpu.py
index bc5952c..5a7ed4e 100644
--- a/test/test_helpers_cpu.py
+++ b/test/test_helpers_cpu.py
@@ -3,6 +3,7 @@
import numpy as np
import optuna
import pandas as pd
+import pytest
import torch
from optuna.distributions import FloatDistribution, IntDistribution
from optuna.samplers import RandomSampler
@@ -81,6 +82,22 @@ def _dummy_candidates(params, values, *_args):
assert 0.0 <= suggestion["width"] <= 1.0
+def test_helper_samplers_reject_removed_compat_kwargs():
+ search_space = {"width": FloatDistribution(0.0, 1.0)}
+
+ with pytest.raises(TypeError):
+ helpers.BoTorchSampler(search_space, consider_running_trials=True)
+
+ with pytest.raises(TypeError):
+ helpers.HEBOSampler(search_space, constant_liar=True)
+
+ with pytest.raises(TypeError):
+ helpers.ImplicitNaturalGradientSampler(search_space, warn_independent_sampling=False)
+
+ with pytest.raises(TypeError):
+ helpers.AutoSampler(search_space=search_space, constraints_func=lambda *_args: None)
+
+
def test_hebo_sampler_observe_and_sample(monkeypatch):
class DummyHEBO:
def __init__(self, *_args, **_kwargs):
diff --git a/test/test_memory_leak.py b/test/test_memory_leak.py
index f7088dc..f6c6f50 100644
--- a/test/test_memory_leak.py
+++ b/test/test_memory_leak.py
@@ -34,7 +34,7 @@ def forward(self, x):
def test_memory(
- opt: str = "NewtonHybrid2PSGDKron",
+ opt: str = "PSGDKron",
size: int = 64,
depth: int = 2,
mars: bool = False,
diff --git a/test/test_merge.py b/test/test_merge.py
index 90bafb2..d1a660f 100644
--- a/test/test_merge.py
+++ b/test/test_merge.py
@@ -18,7 +18,7 @@ def forward(self, inp):
return self.weight.mean() * inp
-@pytest.mark.parametrize("opt", ["ForeachPSGDKron"])
+@pytest.mark.parametrize("opt", ["PSGDKron"])
@pytest.mark.parametrize("size", [(16, 16, 16, 16), (4, 4, 4, 4), (512, 1, 128), (32128, 768)])
@pytest.mark.parametrize("merge,split", [(False, False), (True, False), (True, True)])
def test_merge(
diff --git a/test/test_migrate_cli.py b/test/test_migrate_cli.py
index 1554d09..0130305 100644
--- a/test/test_migrate_cli.py
+++ b/test/test_migrate_cli.py
@@ -1,5 +1,7 @@
import importlib.util
import pathlib
+import pickle
+import random
import sys
import pytest
@@ -14,165 +16,723 @@
sys.modules[MODULE_NAME] = migrate_script
SPEC.loader.exec_module(migrate_script) # type: ignore[arg-type]
+# ---------------------------------------------------------------------------
+# Rename tables (exhaustive, mirrors _CLASS_RENAMES in the script)
+# ---------------------------------------------------------------------------
+
+_DIRECT_RENAMES = [
+ ("ForeachAdamW", "AdamW"),
+ ("ForeachNAdam", "NAdam"),
+ ("ForeachAdEMAMix", "AdEMAMix"),
+ ("ForeachAdamC", "AdamC"),
+ ("ForeachRMSprop", "RMSprop"),
+ ("ForeachSFAdamW", "SFAdamW"),
+ ("ForeachADOPT", "ADOPT"),
+ ("ForeachMuon", "Muon"),
+ ("ForeachLaProp", "LaProp"),
+ ("ForeachSOAP", "SOAP"),
+ ("ForeachSOAPNAdam", "SOAPNAdam"),
+ ("ForeachSOAPAdEMAMix", "SOAPAdEMAMix"),
+ ("ForeachSignLaProp", "SignLaProp"),
+ ("ForeachSOLP", "SOLP"),
+ ("ForeachPSGDKron", "PSGDKron"),
+ ("ForeachPSGDLRA", "PSGDLRA"),
+]
+
+_DELETED_RENAMES = [
+ ("PaLMForeachSFAdamW", "SFAdamW"),
+ ("PaLMForeachSOAP", "SOAP"),
+ ("PrecondScheduleForeachSOAP", "SOAP"),
+ ("PrecondSchedulePaLMForeachSOAP", "SOAP"),
+ ("ForeachPurePSGD", "PSGDKron"),
+ ("ForeachCachedPSGDKron", "PSGDKron"),
+ ("ForeachCachedDelayedPSGDKron", "PSGDKron"),
+ ("ForeachDelayedPSGD", "PSGDKron"),
+ ("ForeachCachedNewtonPSGD", "PSGDKron"),
+ ("NewtonHybrid2PSGDKron", "PSGDKron"),
+ ("ForeachDelayedPSGDLRA", "PSGDLRA"),
+ ("ForeachNewtonPSGDLRA", "PSGDLRA"),
+ ("NewtonHybrid2PSGDLRA", "PSGDLRA"),
+]
+
+_ALL_RENAMES = _DIRECT_RENAMES + _DELETED_RENAMES
+
+# PSGDLRA uses a different constructor signature (beta not betas, rank param, etc.)
+# so e2e tests using AdamW-style param_groups must exclude it
+_LRA_TARGETS = {"PSGDLRA"}
+_DIRECT_RENAMES_NO_LRA = [(o, n) for o, n in _DIRECT_RENAMES if n not in _LRA_TARGETS]
+_DELETED_RENAMES_NO_LRA = [(o, n) for o, n in _DELETED_RENAMES if n not in _LRA_TARGETS]
+_DIRECT_RENAMES_LRA = [(o, n) for o, n in _DIRECT_RENAMES if n in _LRA_TARGETS]
+_DELETED_RENAMES_LRA = [(o, n) for o, n in _DELETED_RENAMES if n in _LRA_TARGETS]
+
+_PASSTHROUGH = ["AdamW", "NAdam", "PSGDKron", "SGD", "SOAP", "Muon", "ADOPT", "LaProp", "PSGDLRA", "SFAdamW"]
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _flat_state(*shapes):
+ return {
+ i: {
+ "update_by_adam_exp_avg": torch.ones(s),
+ "update_by_adam_exp_avg_sq": torch.full(s, 2.0),
+ "is_initialized": [0],
+ }
+ for i, s in enumerate(shapes)
+ }
+
+
+def _nested_state(*shapes):
+ return {i: {0: {"key_0": torch.zeros(s), "is_initialized": [0]}} for i, s in enumerate(shapes)}
+
+
+def _v1_group(pids):
+ return {
+ "params": pids,
+ "lr": 0.0025,
+ "betas": [0.9, 0.99],
+ "eps": 1e-8,
+ "weight_decay": 0.0,
+ "warmup_steps": 0,
+ "foreach": True,
+ "stochastic_schedule": False,
+ "storage_dtype": "float32",
+ "mars": False,
+ "caution": False,
+ "mars_gamma": 0.0025,
+ "gradient_clipping": "use_default",
+ "update_clipping": "use_default",
+ "palm": "use_default",
+ "beta2_scale": 0.8,
+ }
+
+
+def _v2_group(pids):
+ return {**_v1_group(pids), "__class__": "heavyball.ForeachAdamW"}
+
+
+def _v3_group(pids):
+ g = _v1_group(pids)
+ g.pop("foreach")
+ g.pop("stochastic_schedule")
+ g["multi_tensor"] = True
+ return g
+
+
+def _v2_meta():
+ return {
+ "inner_group": {"stochastic_schedule": None},
+ "stochastic_schedule": None,
+ "precond_rng": pickle.dumps(random.Random(0x12312)),
+ "use_ema": False,
+ }
+
+
+def _v3_meta():
+ return {"inner_group": {}, "use_ema": False}
+
+
+def _load_heavyball_fresh():
+ package_root = pathlib.Path(__file__).resolve().parents[1]
+ heavyball_pkg = package_root / "heavyball"
+ saved = {n: sys.modules[n] for n in list(sys.modules) if n == "heavyball" or n.startswith("heavyball.")}
+ for n in list(sys.modules):
+ if n == "heavyball" or n.startswith("heavyball."):
+ sys.modules.pop(n)
+ spec = importlib.util.spec_from_file_location(
+ "heavyball", heavyball_pkg / "__init__.py", submodule_search_locations=[str(heavyball_pkg)]
+ )
+ mod = importlib.util.module_from_spec(spec)
+ sys.modules["heavyball"] = mod
+ spec.loader.exec_module(mod)
+ return saved
+
+
+def _restore_heavyball(saved):
+ for n in list(sys.modules):
+ if n == "heavyball" or n.startswith("heavyball."):
+ sys.modules.pop(n)
+ sys.modules.update(saved)
+
@pytest.fixture()
def runner():
return CliRunner()
-def test_cli_dry_run_updates_state(monkeypatch, runner, tmp_path):
- checkpoint_path = tmp_path / "ckpt.pt"
- checkpoint_path.touch()
+# ====================================================================
+# _resolve_class_name
+# ====================================================================
- state_container = {"state": {"initial": True}, "param_groups": ["group"]}
- checkpoint = {"optimizer": state_container}
- def fake_load(path, map_location=None):
- return checkpoint
+@pytest.mark.parametrize("old,new", _ALL_RENAMES)
+def test_resolve_class_name_renames(old, new):
+ assert migrate_script._resolve_class_name(old) == new
- def fake_migrate(state, _):
- return {"state": {"migrated": True}, "param_groups": []}
- def fail_save(*args, **kwargs):
- pytest.fail("torch.save should not run during dry-run")
+@pytest.mark.parametrize("name", _PASSTHROUGH)
+def test_resolve_class_name_passthrough(name):
+ assert migrate_script._resolve_class_name(name) == name
- monkeypatch.setattr(migrate_script.torch, "load", fake_load)
- monkeypatch.setattr(migrate_script, "migrate_state_dict", fake_migrate)
- monkeypatch.setattr(migrate_script.torch, "save", fail_save)
- result = runner.invoke(
- migrate_script.app,
- [str(checkpoint_path), "heavyball.Mock", "--state-key", "optimizer", "--dry-run"],
- )
+# ====================================================================
+# _load_optimizer_class
+# ====================================================================
+
+
+@pytest.mark.parametrize("old,new", _ALL_RENAMES)
+def test_load_optimizer_class_old_names(old, new):
+ import heavyball
+
+ assert migrate_script._load_optimizer_class(f"heavyball.{old}") is getattr(heavyball, new)
+
+
+@pytest.mark.parametrize("name", _PASSTHROUGH)
+def test_load_optimizer_class_current_names(name):
+ import heavyball
+
+ assert migrate_script._load_optimizer_class(name) is getattr(heavyball, name)
+
+
+def test_load_optimizer_class_bare_name():
+ import heavyball
+
+ assert migrate_script._load_optimizer_class("AdamW") is heavyball.AdamW
+
+
+def test_load_optimizer_class_invalid_raises():
+ with pytest.raises(ValueError, match="not found"):
+ migrate_script._load_optimizer_class("heavyball.TotallyFakeOptimizer")
+
+
+# ====================================================================
+# _detect_version
+# ====================================================================
+
+_DETECT_CASES = [
+ ("flat_foreach", {"state": _flat_state((2,)), "param_groups": [{"foreach": True}]}, 1),
+ ("flat_multi_tensor", {"state": _flat_state((2,)), "param_groups": [{"multi_tensor": True}]}, 1),
+ ("flat_neither", {"state": _flat_state((2,)), "param_groups": [{}]}, 1),
+ ("flat_multi_group", {"state": _flat_state((2,), (3,)), "param_groups": [{"foreach": True}, {}]}, 1),
+ ("nested_foreach", {"state": _nested_state((2,)), "param_groups": [{"foreach": True}]}, 2),
+ (
+ "nested_foreach_multi_group",
+ {"state": _nested_state((2,), (3,)), "param_groups": [{"foreach": True}, {"foreach": True}]},
+ 2,
+ ),
+ ("nested_multi_tensor", {"state": _nested_state((2,)), "param_groups": [{"multi_tensor": True}]}, 3),
+ (
+ "nested_multi_tensor_multi_group",
+ {"state": _nested_state((2,), (3,)), "param_groups": [{"multi_tensor": True}, {}]},
+ 3,
+ ),
+ ("empty_state_multi_tensor", {"state": {}, "param_groups": [{"multi_tensor": True}]}, 3),
+ ("empty_state_foreach", {"state": {}, "param_groups": [{"foreach": True}]}, 1),
+ ("empty_state_bare", {"state": {}, "param_groups": [{}]}, 1),
+ ("empty_state_empty_groups", {"state": {}, "param_groups": []}, 1),
+]
+
+
+@pytest.mark.parametrize("name,sd,expected", _DETECT_CASES, ids=[c[0] for c in _DETECT_CASES])
+def test_detect_version(name, sd, expected):
+ assert migrate_script._detect_version(sd) == expected
+
+
+# ====================================================================
+# _normalise_group_options
+# ====================================================================
+
+_NORM_CASES = [
+ ("foreach_renamed", {"foreach": True, "lr": 0.01}, {"multi_tensor": True, "lr": 0.01}),
+ ("foreach_false", {"foreach": False, "lr": 0.01}, {"multi_tensor": False, "lr": 0.01}),
+ ("multi_tensor_passthrough", {"multi_tensor": False, "lr": 0.01}, {"multi_tensor": False, "lr": 0.01}),
+ ("stochastic_stripped", {"foreach": True, "stochastic_schedule": False}, {"multi_tensor": True}),
+ ("betas_tuple", {"betas": [0.9, 0.999]}, {"betas": (0.9, 0.999)}),
+ ("weight_decay_steps_tuple", {"weight_decay_steps": [100, 200]}, {"weight_decay_steps": (100, 200)}),
+ ("params_stripped", {"params": [0, 1], "lr": 0.01}, {"lr": 0.01}),
+ ("empty_group", {}, {}),
+ ("params_only", {"params": [0]}, {}),
+ ("all_removed", {"params": [0], "stochastic_schedule": None}, {}),
+ (
+ "kitchen_sink",
+ {"params": [0, 1], "foreach": True, "stochastic_schedule": True, "lr": 0.01, "betas": [0.9, 0.99]},
+ {"multi_tensor": True, "lr": 0.01, "betas": (0.9, 0.99)},
+ ),
+]
+
+
+@pytest.mark.parametrize("name,group,expected", _NORM_CASES, ids=[c[0] for c in _NORM_CASES])
+def test_normalise_group_options(name, group, expected):
+ assert migrate_script._normalise_group_options(group) == expected
+
+
+# ====================================================================
+# _ensure_set
+# ====================================================================
+
+
+@pytest.mark.parametrize(
+ "inp,expected",
+ [
+ ({1, 2}, {1, 2}),
+ ([3, 4], {3, 4}),
+ ((5,), {5}),
+ (None, set()),
+ (7, {7}),
+ ([], set()),
+ ],
+ ids=["set", "list", "tuple", "none", "scalar", "empty_list"],
+)
+def test_ensure_set(inp, expected):
+ assert migrate_script._ensure_set(inp) == expected
+
+
+# ====================================================================
+# _guess_tensor_meta
+# ====================================================================
+
+
+@pytest.mark.parametrize(
+ "entry,shape,dtype",
+ [
+ ({"a": torch.zeros(3, 4, dtype=torch.float16)}, (3, 4), torch.float16),
+ ({"a": [torch.zeros(2, dtype=torch.bfloat16)]}, (2,), torch.bfloat16),
+ ({"a": "not_a_tensor", "b": torch.ones(5)}, (5,), torch.float32),
+ ({}, (1,), torch.float32),
+ ({"a": 42}, (1,), torch.float32),
+ ],
+ ids=["tensor", "tensor_list", "mixed_finds_tensor", "empty", "no_tensors"],
+)
+def test_guess_tensor_meta(entry, shape, dtype):
+ s, d = migrate_script._guess_tensor_meta(entry)
+ assert s == shape
+ assert d == dtype
+
+
+# ====================================================================
+# _resolve_state_container
+# ====================================================================
+
+
+def test_resolve_state_container_single_key():
+ root = {"optimizer": {"state": {}, "param_groups": []}}
+ assert migrate_script._resolve_state_container(root, ["optimizer"]) is root["optimizer"]
+
+
+def test_resolve_state_container_nested_key():
+ inner = {"state": {}, "param_groups": []}
+ root = {"model": {"training": {"opt": inner}}}
+ assert migrate_script._resolve_state_container(root, ["model", "training", "opt"]) is inner
+
+
+def test_resolve_state_container_missing_key():
+ with pytest.raises(KeyError, match="not found"):
+ migrate_script._resolve_state_container({"a": {}}, ["a", "b"])
+
+
+def test_resolve_state_container_not_optimizer():
+ with pytest.raises(ValueError, match="not an optimizer state dict"):
+ migrate_script._resolve_state_container({"opt": {"just_data": 1}}, ["opt"])
+
+
+# ====================================================================
+# _migrate_v2_to_v3
+# ====================================================================
+
+
+@pytest.mark.parametrize(
+ "meta_keys",
+ [
+ {"stochastic_schedule": None, "precond_rng": pickle.dumps(random.Random(0))},
+ {"stochastic_schedule": None},
+ {"precond_rng": b"rng"},
+ {},
+ ],
+ ids=["both", "stochastic_only", "precond_rng_only", "clean"],
+)
+def test_migrate_v2_to_v3_strips_stale_meta(meta_keys):
+ sd = {
+ "state": _nested_state((3,)),
+ "param_groups": [{"params": [0], "foreach": True, "stochastic_schedule": False}],
+ "heavyball": {"inner_group": {"stochastic_schedule": None}, "use_ema": False, **meta_keys},
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ group = sd["param_groups"][0]
+ assert group["multi_tensor"] is True
+ assert "foreach" not in group
+ assert "stochastic_schedule" not in group
+ hb = sd["heavyball"]
+ for key in ("stochastic_schedule", "precond_rng"):
+ assert key not in hb
+ assert key not in hb["inner_group"]
+ assert hb["use_ema"] is False
+
+
+def test_migrate_v2_to_v3_multi_group():
+ sd = {
+ "state": _nested_state((2,), (3,)),
+ "param_groups": [
+ {"params": [0], "foreach": True, "stochastic_schedule": False},
+ {"params": [1], "foreach": False, "stochastic_schedule": True, "lr": 0.1},
+ ],
+ "heavyball": _v2_meta(),
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ assert sd["param_groups"][0]["multi_tensor"] is True
+ assert sd["param_groups"][1]["multi_tensor"] is False
+ assert sd["param_groups"][1]["lr"] == 0.1
+ for g in sd["param_groups"]:
+ assert "foreach" not in g
+ assert "stochastic_schedule" not in g
+
+
+def test_migrate_v2_to_v3_preserves_non_stale_meta():
+ sd = {
+ "state": _nested_state((2,)),
+ "param_groups": [{"params": [0], "foreach": True}],
+ "heavyball": {**_v2_meta(), "use_ema": True, "ema_decay": 0.005, "compile_step": True},
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ assert sd["heavyball"]["use_ema"] is True
+ assert sd["heavyball"]["ema_decay"] == 0.005
+ assert sd["heavyball"]["compile_step"] is True
+
+
+def test_migrate_v2_to_v3_no_heavyball_key():
+ sd = {
+ "state": _nested_state((2,)),
+ "param_groups": [{"params": [0], "foreach": True}],
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ assert sd["param_groups"][0]["multi_tensor"] is True
+
+
+# ====================================================================
+# _migrate_single_state
+# ====================================================================
+
+
+@pytest.mark.parametrize("n_views", [1, 2, 5])
+def test_migrate_single_state_rewrites_keys(n_views):
+ mappings = [
+ migrate_script.TransformMapping("exp_avg", "exp_avg_0", 0),
+ migrate_script.TransformMapping("exp_avg_sq", "exp_avg_sq_0", 0),
+ ]
+ entry = {
+ "exp_avg": [torch.ones(2)] * n_views if n_views > 1 else torch.ones(2),
+ "exp_avg_sq": [torch.full((2,), 2.0)] * n_views if n_views > 1 else torch.full((2,), 2.0),
+ "is_initialized": [0],
+ }
+ migrated = migrate_script._migrate_single_state(entry, mappings)
+ for view_idx in range(n_views):
+ bucket = migrated[view_idx]
+ assert "exp_avg_0" in bucket
+ assert "exp_avg_sq_0" in bucket
+ assert "exp_avg" not in bucket
+ assert "exp_avg_sq" not in bucket
+ assert 0 in bucket["is_initialized"]
+
+
+def test_migrate_single_state_empty_mappings():
+ entry = {"some_key": torch.ones(3), "is_initialized": {0, 1}}
+ migrated = migrate_script._migrate_single_state(entry, [])
+ assert "some_key" in migrated[0]
+ assert set(migrated[0]["is_initialized"]) == {0, 1}
+
+
+def test_migrate_single_state_multiple_transforms():
+ mappings = [
+ migrate_script.TransformMapping("exp_avg", "update_by_adam_exp_avg_0", 0),
+ migrate_script.TransformMapping("exp_avg_sq", "update_by_adam_exp_avg_sq_0", 0),
+ migrate_script.TransformMapping("momentum", "heavyball_momentum_momentum_1", 1),
+ ]
+ entry = {
+ "exp_avg": torch.ones(4),
+ "exp_avg_sq": torch.full((4,), 2.0),
+ "momentum": torch.full((4,), 0.5),
+ "is_initialized": [0, 1],
+ }
+ migrated = migrate_script._migrate_single_state(entry, mappings)
+ bucket = migrated[0]
+ assert "update_by_adam_exp_avg_0" in bucket
+ assert "update_by_adam_exp_avg_sq_0" in bucket
+ assert "heavyball_momentum_momentum_1" in bucket
+ assert set(bucket["is_initialized"]) == {0, 1}
+
+
+def test_migrate_single_state_preserves_values():
+ t = torch.arange(6, dtype=torch.float32)
+ mappings = [migrate_script.TransformMapping("old", "new_0", 0)]
+ migrated = migrate_script._migrate_single_state({"old": t, "is_initialized": [0]}, mappings)
+ assert torch.equal(migrated[0]["new_0"], t)
+
+# ====================================================================
+# Full migrate_state_dict
+# ====================================================================
+
+
+def test_migrate_v2_state_dict():
+ sd = {
+ "state": _nested_state((4,)),
+ "param_groups": [_v2_group([0])],
+ "heavyball": _v2_meta(),
+ }
+ migrated = migrate_script.migrate_state_dict(sd, "heavyball.ForeachAdamW")
+ group = migrated["param_groups"][0]
+ assert group["multi_tensor"] is True
+ assert "foreach" not in group
+ assert "stochastic_schedule" not in group
+ hb = migrated.get("heavyball", {})
+ assert "stochastic_schedule" not in hb
+ assert "precond_rng" not in hb
+
+
+def test_migrate_v3_is_noop():
+ sd = {
+ "state": _nested_state((4,)),
+ "param_groups": [_v3_group([0])],
+ "heavyball": _v3_meta(),
+ }
+ original_key0 = sd["state"][0][0]["key_0"].clone()
+ migrated = migrate_script.migrate_state_dict(sd, "heavyball.AdamW")
+ assert migrated is not sd
+ assert torch.equal(migrated["state"][0][0]["key_0"], original_key0)
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+
+
+def test_migrate_v2_with_old_class_name():
+ sd = {
+ "state": _nested_state((4,)),
+ "param_groups": [_v2_group([0])],
+ "heavyball": _v2_meta(),
+ }
+ migrated = migrate_script.migrate_state_dict(sd, "ForeachAdamW")
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+
+
+def test_migrate_v2_multi_param():
+ sd = {
+ "state": _nested_state((4, 4), (8,), (2, 2, 2)),
+ "param_groups": [_v2_group([0, 1, 2])],
+ "heavyball": _v2_meta(),
+ }
+ migrated = migrate_script.migrate_state_dict(sd, "AdamW")
+ for pid in (0, 1, 2):
+ assert pid in migrated["state"]
+ assert 0 in migrated["state"][pid]
+
+
+# ====================================================================
+# CLI (typer) | mocked
+# ====================================================================
+
+
+def test_cli_dry_run(monkeypatch, runner, tmp_path):
+ ckpt = tmp_path / "ckpt.pt"
+ ckpt.touch()
+ container = {"state": {}, "param_groups": []}
+ checkpoint = {"optimizer": container}
+
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: checkpoint)
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {"ok": True}, "param_groups": []})
+ monkeypatch.setattr(migrate_script.torch, "save", lambda *a, **kw: pytest.fail("save during dry-run"))
+
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--dry-run"])
assert result.exit_code == 0
assert "Dry run" in result.stdout
- assert state_container == {"state": {"migrated": True}, "param_groups": []}
+ assert container["state"] == {"ok": True}
def test_cli_writes_output(monkeypatch, runner, tmp_path):
- checkpoint_path = tmp_path / "source.pt"
- checkpoint_path.touch()
- output_path = tmp_path / "out.pt"
-
- state_container = {"state": {"initial": True}, "param_groups": ["group"]}
- checkpoint = {"optimizer": state_container}
- migrated = {"state": {"migrated": True}, "param_groups": []}
+ ckpt = tmp_path / "source.pt"
+ ckpt.touch()
+ out = tmp_path / "out.pt"
+ migrated = {"state": {"done": True}, "param_groups": []}
+ checkpoint = {"optimizer": {"state": {}, "param_groups": []}}
+ saved = {}
- def fake_load(path, map_location=None):
- return checkpoint
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: checkpoint)
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: migrated)
+ monkeypatch.setattr(migrate_script.torch, "save", lambda obj, path: saved.update(obj=obj, path=pathlib.Path(path)))
- def fake_migrate(state, _):
- return migrated
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--output", str(out)])
+ assert result.exit_code == 0
+ assert saved["path"] == out
+ assert saved["obj"]["optimizer"] == migrated
- monkeypatch.setattr(migrate_script.torch, "load", fake_load)
- monkeypatch.setattr(migrate_script, "migrate_state_dict", fake_migrate)
+def test_cli_overwrites_input_by_default(monkeypatch, runner, tmp_path):
+ ckpt = tmp_path / "ckpt.pt"
+ ckpt.touch()
saved = {}
- def fake_save(obj, path):
- saved["obj"] = obj
- saved["path"] = pathlib.Path(path)
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: {"optimizer": {"state": {}, "param_groups": []}})
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {}, "param_groups": []})
+ monkeypatch.setattr(migrate_script.torch, "save", lambda obj, path: saved.update(path=pathlib.Path(path)))
+
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock"])
+ assert result.exit_code == 0
+ assert saved["path"] == ckpt
- monkeypatch.setattr(migrate_script.torch, "save", fake_save)
- result = runner.invoke(
- migrate_script.app,
- [str(checkpoint_path), "heavyball.Mock", "--output", str(output_path)],
- )
+@pytest.mark.parametrize("key", ["optimizer", "model.optimizer", "a.b.c"])
+def test_cli_state_key_parsing(monkeypatch, runner, tmp_path, key):
+ ckpt = tmp_path / "ckpt.pt"
+ ckpt.touch()
+ parts = key.split(".")
+ inner = {"state": {}, "param_groups": []}
+ root = inner
+ for p in reversed(parts):
+ root = {p: root}
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: root)
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {}, "param_groups": []})
+ monkeypatch.setattr(migrate_script.torch, "save", lambda *a, **kw: None)
+
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--state-key", key])
assert result.exit_code == 0
- assert saved["path"] == output_path
- assert saved["obj"]["optimizer"] == migrated
-def test_cli_migrates_legacy_checkpoint(runner, tmp_path):
- package_root = pathlib.Path(__file__).resolve().parents[1]
- heavyball_pkg = package_root / "heavyball"
- saved_heavyball_modules = {
- name: sys.modules[name] for name in list(sys.modules) if name == "heavyball" or name.startswith("heavyball.")
- }
- for name in list(sys.modules):
- if name == "heavyball" or name.startswith("heavyball."):
- sys.modules.pop(name)
+# ====================================================================
+# CLI - real end-to-end (fresh heavyball import)
+# ====================================================================
- spec = importlib.util.spec_from_file_location(
- "heavyball",
- heavyball_pkg / "__init__.py",
- submodule_search_locations=[str(heavyball_pkg)],
- )
- heavyball_module = importlib.util.module_from_spec(spec)
- sys.modules["heavyball"] = heavyball_module
- spec.loader.exec_module(heavyball_module)
- try:
- checkpoint_path = tmp_path / "legacy.pt"
- output_path = tmp_path / "migrated.pt"
-
- legacy_state = {
- "state": {
- 0: {
- "update_by_adam_exp_avg": torch.ones((2, 2), dtype=torch.float32),
- "update_by_adam_exp_avg_sq": torch.full((2, 2), 2.0, dtype=torch.float32),
- "is_initialized": [0],
- },
- 1: {
- "update_by_adam_exp_avg": torch.ones((2,), dtype=torch.float32),
- "update_by_adam_exp_avg_sq": torch.full((2,), 2.0, dtype=torch.float32),
- "is_initialized": [0],
- },
- },
- "param_groups": [
- {
- "params": [0, 1],
- "lr": 0.0025,
- "betas": [0.9, 0.99],
- "eps": 1e-8,
- "weight_decay": 0.0,
- "warmup_steps": 0,
- "foreach": True,
- "storage_dtype": "float32",
- "mars": False,
- "caution": False,
- "mars_gamma": 0.0025,
- "gradient_clipping": "use_default",
- "update_clipping": "use_default",
- "palm": "use_default",
- "beta2_scale": 0.8,
- "__class__": "heavyball.ForeachAdamW",
- }
- ],
- }
+def _make_v1_checkpoint(path, shapes, *, group_overrides=None):
+ g = _v1_group(list(range(len(shapes))))
+ g["multi_tensor"] = True
+ g.pop("foreach")
+ g.pop("stochastic_schedule")
+ if group_overrides:
+ g.update(group_overrides)
+ torch.save({"optimizer": {"state": _flat_state(*shapes), "param_groups": [g]}}, path)
- torch.save({"optimizer": legacy_state}, checkpoint_path)
- result = runner.invoke(
- migrate_script.app,
- [str(checkpoint_path), "heavyball.ForeachAdamW", "--output", str(output_path)],
- )
+def _make_v2_checkpoint(path, shapes, *, group_overrides=None):
+ g = _v2_group(list(range(len(shapes))))
+ if group_overrides:
+ g.update(group_overrides)
+ torch.save({"optimizer": {"state": _nested_state(*shapes), "param_groups": [g], "heavyball": _v2_meta()}}, path)
+
+def _make_v3_checkpoint(path, shapes, *, group_overrides=None):
+ g = _v3_group(list(range(len(shapes))))
+ if group_overrides:
+ g.update(group_overrides)
+ torch.save({"optimizer": {"state": _nested_state(*shapes), "param_groups": [g], "heavyball": _v3_meta()}}, path)
+
+
+def _run_cli_e2e(runner, ckpt, out, class_name):
+ saved = _load_heavyball_fresh()
+ try:
+ result = runner.invoke(migrate_script.app, [str(ckpt), class_name, "--output", str(out)])
assert result.exit_code == 0, result.stderr or result.stdout
- assert output_path.exists()
-
- migrated = torch.load(output_path)
- migrated_state = migrated["optimizer"]
-
- for pid, shape in [(0, (2, 2)), (1, (2,))]:
- assert pid in migrated_state["state"], f"missing state for parameter {pid}"
- migrated_bucket = migrated_state["state"][pid]
- assert 0 in migrated_bucket, f"missing transformed view for parameter {pid}"
- view_state = migrated_bucket[0]
- assert view_state["update_by_adam_exp_avg_0"].shape == shape
- assert view_state["update_by_adam_exp_avg_sq_0"].shape == shape
- assert torch.allclose(view_state["update_by_adam_exp_avg_0"], torch.ones(shape))
- assert torch.allclose(view_state["update_by_adam_exp_avg_sq_0"], torch.full(shape, 2.0))
- assert view_state["is_initialized"] == [0]
-
- assert "heavyball" in migrated_state
- assert migrated_state["heavyball"]["inner_group"]["stochastic_schedule"] is None
- assert "Migrated checkpoint written to" in result.stdout
+ return torch.load(out)["optimizer"]
finally:
- for name in list(sys.modules):
- if name == "heavyball" or name.startswith("heavyball."):
- sys.modules.pop(name)
- sys.modules.update(saved_heavyball_modules)
+ _restore_heavyball(saved)
+
+
+# --- v1 end-to-end ---
+
+
+def test_e2e_v1_adamw(runner, tmp_path):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [(2, 2), (2,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ for pid, shape in [(0, (2, 2)), (1, (2,))]:
+ view = migrated["state"][pid][0]
+ assert view["update_by_adam_exp_avg_0"].shape == shape
+ assert view["update_by_adam_exp_avg_sq_0"].shape == shape
+ assert torch.allclose(view["update_by_adam_exp_avg_0"], torch.ones(shape))
+ assert view["is_initialized"] == [0]
+ assert "heavyball" in migrated
+
+
+@pytest.mark.parametrize("old_name,new_name", _DIRECT_RENAMES_NO_LRA)
+def test_e2e_v1_all_direct_renames(runner, tmp_path, old_name, new_name):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}")
+ assert 0 in migrated["state"]
+ assert "heavyball" in migrated
+
+
+@pytest.mark.parametrize("old_name,new_name", _DELETED_RENAMES_NO_LRA)
+def test_e2e_v1_all_deleted_renames(runner, tmp_path, old_name, new_name):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}")
+ assert 0 in migrated["state"]
+ assert "heavyball" in migrated
+
+
+# PSGDLRA v1 e2e: PSGDLRA computes rank from param numel during __init__,
+# which the migration script's _instantiate_optimizer can't replicate from
+# a state dict alone (rank isn't stored in param_groups). The rename
+# resolution and migration logic for LRA is covered by the unit tests above.
+
+
+@pytest.mark.parametrize("n_params", [1, 3, 5])
+def test_e2e_v1_varying_param_count(runner, tmp_path, n_params):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ shapes = [(i + 2,) for i in range(n_params)]
+ _make_v1_checkpoint(ckpt, shapes)
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ for pid in range(n_params):
+ assert pid in migrated["state"]
+
+
+@pytest.mark.parametrize("shape", [(4,), (2, 3), (2, 3, 4)])
+def test_e2e_v1_varying_shapes(runner, tmp_path, shape):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [shape])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ assert migrated["state"][0][0]["update_by_adam_exp_avg_0"].shape == shape
+
+
+# --- v2 end-to-end ---
+
+
+def test_e2e_v2_basic(runner, tmp_path):
+ ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt"
+ _make_v2_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ group = migrated["param_groups"][0]
+ assert group["multi_tensor"] is True
+ assert "foreach" not in group
+ assert "stochastic_schedule" not in group
+ hb = migrated["heavyball"]
+ assert "precond_rng" not in hb
+ assert "stochastic_schedule" not in hb
+
+
+@pytest.mark.parametrize("old_name,new_name", _DIRECT_RENAMES_NO_LRA)
+def test_e2e_v2_all_direct_renames(runner, tmp_path, old_name, new_name):
+ ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt"
+ _make_v2_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}")
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+ assert "foreach" not in migrated["param_groups"][0]
+
+
+@pytest.mark.parametrize("n_params", [1, 3, 5])
+def test_e2e_v2_varying_param_count(runner, tmp_path, n_params):
+ ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt"
+ shapes = [(i + 2,) for i in range(n_params)]
+ _make_v2_checkpoint(ckpt, shapes)
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ for pid in range(n_params):
+ assert pid in migrated["state"]
+
+
+# --- v3 end-to-end (noop) ---
+
+
+def test_e2e_v3_noop(runner, tmp_path):
+ ckpt, out = tmp_path / "v3.pt", tmp_path / "out.pt"
+ _make_v3_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+ assert torch.equal(migrated["state"][0][0]["key_0"], torch.zeros(4))
diff --git a/test/test_optimizer_cpu_smoke.py b/test/test_optimizer_cpu_smoke.py
index 9a9fa68..7875d5f 100644
--- a/test/test_optimizer_cpu_smoke.py
+++ b/test/test_optimizer_cpu_smoke.py
@@ -1,9 +1,11 @@
+import functools
import inspect
import pytest
import torch
import heavyball
+from heavyball import chainable as C
from heavyball.utils import StatefulOptimizer
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -31,7 +33,7 @@ def _optimizer_params():
)
)
continue
- if name == "ForeachSOAPNAdam":
+ if name == "SOAPNAdam":
params.append(
pytest.param(
name,
@@ -80,3 +82,44 @@ def closure():
clone = opt_cls(model.parameters())
clone.load_state_dict(state_dict)
assert clone.state_dict()["state"].keys() == state_dict["state"].keys()
+
+
+def test_optimizer_keeps_constructor_compatibility_features():
+ param = torch.nn.Parameter(torch.randn(4, 4, device=DEVICE))
+
+ with pytest.warns(FutureWarning, match="renamed to 'multi_tensor'"):
+ optimizer = heavyball.AdamW([param], foreach=True)
+ assert optimizer.param_groups[0]["multi_tensor"] is True
+
+ with pytest.raises(TypeError, match="Removed in HeavyBall"):
+ heavyball.SOAP([param], normalize_grads=True)
+
+ with pytest.warns(UserWarning, match="Working with uncaptured keyword arguments"):
+ heavyball.AdamW([param], totally_fake=True)
+
+
+def test_optimizer_accepts_explicit_orig_shapes():
+ param = torch.nn.Parameter(torch.randn(4, 4, device=DEVICE))
+ shapes = heavyball.capture_param_shapes([param])
+ optimizer = heavyball.AdamW([param], orig_shapes=shapes)
+ assert "orig_shapes" not in optimizer.param_groups[0]
+
+
+def test_subclass_defaults_still_apply():
+ class ScheduledSOAP(heavyball.SOAP):
+ use_precond_schedule = True
+
+ class DelayedPSGDKron(heavyball.PSGDKron):
+ delayed = True
+ exp_avg_input = False
+
+ param = torch.nn.Parameter(torch.randn(4, 4, device=DEVICE))
+
+ soap = ScheduledSOAP([param])
+ assert "precondition_frequency" not in soap.param_groups[0]
+ assert "precond_scheduler" not in soap.param_groups[0]
+
+ psgd = DelayedPSGDKron([param])
+ first_fn = psgd.fns[0]
+ assert isinstance(first_fn, functools.partial)
+ assert first_fn.func.get_fn() is C.scale_by_delayed_psgd.get_fn()
diff --git a/test/test_param_ecc_compile.py b/test/test_param_ecc_compile.py
index 7182344..6b8c462 100644
--- a/test/test_param_ecc_compile.py
+++ b/test/test_param_ecc_compile.py
@@ -70,7 +70,7 @@ def _train_linear(compile_mode, rne=False, steps=50):
try:
with ctx:
model = torch.nn.Linear(64, 32, bias=False, device="cuda")
- opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, param_ecc="bf16+8")
+ opt = heavyball.AdamW(model.parameters(), lr=1e-2, param_ecc="bf16+8")
data = torch.randn(32, 64, device="cuda")
target = torch.randn(32, 32, device="cuda")
for _ in range(steps):
@@ -108,7 +108,7 @@ def test_ecc_populated_rne():
p, ecc, _ = _train_linear("max-autotune-no-cudagraphs", rne=True)
assert ecc is not None, "param::ecc not found in optimizer state"
assert ecc.any(), (
- "param::ecc all zeros with RNE rounding under compile — Inductor likely folded the f32->bf16->f32 roundtrip"
+ "param::ecc all zeros with RNE rounding under compile, Inductor likely folded the f32->bf16->f32 roundtrip"
)
clean()
diff --git a/test/test_stochastic_updates.py b/test/test_stochastic_updates.py
deleted file mode 100644
index bd3f409..0000000
--- a/test/test_stochastic_updates.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import pytest
-import torch
-from torch import nn
-from utils import REPRESENTATIVE_OPTS
-
-import heavyball
-from heavyball.utils import clean, set_torch
-
-_SAVED_COMPILE_MODE = heavyball.utils.compile_mode
-heavyball.utils.compile_mode = "default"
-
-
-@pytest.fixture(autouse=True)
-def _isolate_compile_mode():
- heavyball.utils.compile_mode = "default"
- yield
- heavyball.utils.compile_mode = _SAVED_COMPILE_MODE
-
-
-PSGD_OPTS = [o for o in REPRESENTATIVE_OPTS if "PSGD" in o]
-
-
-@pytest.mark.parametrize("opt", PSGD_OPTS)
-def test_foreach(opt, size: int = 128, depth: int = 1, iterations: int = 512, outer_iterations: int = 2):
- set_torch()
-
- opt = getattr(heavyball, opt)
-
- losses = []
-
- for stochastic in [False, True]:
- print("stochastic", stochastic)
- torch.manual_seed(0x2131290)
- losses.append([])
-
- for i in range(outer_iterations):
- model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
- o = opt(
- model.parameters(),
- lr=1e-3,
- stochastic_schedule=stochastic,
- preconditioner_update_probability=lambda step: 0.1,
- )
-
- for _ in range(iterations):
- loss = model(torch.randn((128, size), device="cuda")).square().mean()
- loss.backward()
- o.step()
- o.zero_grad()
- losses[-1].append(loss.detach())
-
- del model, o
- clean()
-
- stochastic = sum([l.item() for l in losses[1]])
- deterministic = sum([l.item() for l in losses[0]])
- print(f"{deterministic=}, {stochastic=}")
- assert not torch.isclose(torch.tensor(deterministic), torch.tensor(stochastic), rtol=0.01)
diff --git a/test/test_stochastic_utils_cpu.py b/test/test_stochastic_utils_cpu.py
index 0b16f2d..693533d 100644
--- a/test/test_stochastic_utils_cpu.py
+++ b/test/test_stochastic_utils_cpu.py
@@ -1,8 +1,6 @@
-import os
-
-os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
-
+import pytest
import torch
+from torch._dynamo import config
import heavyball
from heavyball.utils import (
@@ -13,7 +11,15 @@
stochastic_divide_with_eps_,
)
-heavyball.utils.atan2_scale = 1024.0
+config.cache_size_limit = 128
+
+
+@pytest.fixture(autouse=True)
+def _restore_atan2_scale():
+ orig = heavyball.utils.atan2_scale
+ heavyball.utils.atan2_scale = 1024.0
+ yield
+ heavyball.utils.atan2_scale = orig
def _average_stochastic_round(source: torch.Tensor, trials: int = 512) -> torch.Tensor:
diff --git a/test/test_toy_training.py b/test/test_toy_training.py
index 08859c5..6a729af 100644
--- a/test/test_toy_training.py
+++ b/test/test_toy_training.py
@@ -23,16 +23,14 @@ def _flatten_tensors(tensors: Iterable[torch.Tensor]):
EXTRA_KWARGS = {
- "ForeachAdamC": {"max_lr": 0.0025},
+ "AdamC": {"max_lr": 0.0025},
}
def _optimizer_params():
params = []
- for name in sorted(dir(heavyball)):
- if name.startswith("_"):
- continue
- attr = getattr(heavyball, name)
+ for name in sorted(heavyball.__all__):
+ attr = getattr(heavyball, name, None)
if not isinstance(attr, type) or not issubclass(attr, optim.Optimizer):
continue
if attr is optim.Optimizer:
@@ -62,15 +60,15 @@ def toy_training_results(request):
sig = inspect.signature(optimizer_cls.__init__)
kwargs = dict(EXTRA_KWARGS.get(optimizer_name, {}))
- if "foreach" in sig.parameters:
- kwargs["foreach"] = True
+ if "multi_tensor" in sig.parameters:
+ kwargs["multi_tensor"] = True
if optimizer_name == "SAMWrapper":
inner_kwargs = {}
- inner_sig = inspect.signature(heavyball.ForeachAdamW.__init__)
- if "foreach" in inner_sig.parameters:
- inner_kwargs["foreach"] = True
- inner_optimizer = heavyball.ForeachAdamW(param_list, **inner_kwargs)
+ inner_sig = inspect.signature(heavyball.AdamW.__init__)
+ if "multi_tensor" in inner_sig.parameters:
+ inner_kwargs["multi_tensor"] = True
+ inner_optimizer = heavyball.AdamW(param_list, **inner_kwargs)
optimizer = optimizer_cls(param_list, wrapped_optimizer=inner_optimizer, **kwargs)
else:
optimizer = optimizer_cls(param_list, **kwargs)
diff --git a/test/test_utils_cpu.py b/test/test_utils_cpu.py
index 54786a4..b9af3d7 100644
--- a/test/test_utils_cpu.py
+++ b/test/test_utils_cpu.py
@@ -1,5 +1,4 @@
import os
-import random
import warnings
from copy import deepcopy
@@ -40,6 +39,9 @@
# Ensure Torch dynamo stays disabled on CI runners without GPU support.
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
+from torch._dynamo import config
+
+config.cache_size_limit = 128
_SAVED_COMPILE_MODE = heavyball.utils.compile_mode
heavyball.utils.compile_mode = None
@@ -173,9 +175,9 @@ def test_triu_line_roundtrip_on_cpu():
torch.arange(9, dtype=torch.float32).reshape(3, 3),
]
packed = triu_to_line(tensors)
- restored = line_to_triu(packed, symmetric_output=True)
+ restored = line_to_triu(packed)
for original, rebuilt in zip(tensors, restored, strict=True):
- assert torch.allclose(rebuilt, torch.triu(original) + torch.triu(original, diagonal=1).T)
+ assert torch.allclose(rebuilt, torch.triu(original))
def test_warn_once_only_emits_single_warning(monkeypatch):
@@ -192,7 +194,7 @@ def test_warn_once_only_emits_single_warning(monkeypatch):
def test_psgd_should_update_accumulates_probability():
- group = {"stochastic_schedule": False}
+ group = {}
outcomes = [psgd_should_update(group, 0.4) for _ in range(4)]
assert outcomes[:2] == [False, False]
assert outcomes[2] is True
@@ -200,15 +202,6 @@ def test_psgd_should_update_accumulates_probability():
assert group["cumulative_prob_prob_step"] == 4
-def test_psgd_should_update_stochastic_schedule_uses_rng():
- rng = random.Random(123)
- group = {"stochastic_schedule": True}
- calls = [psgd_should_update(group, 0.5, rng=rng) for _ in range(5)]
- rng = random.Random(123)
- expected = [rng.random() < 0.5 for _ in range(5)]
- assert calls == expected
-
-
def test_stochastic_math_helpers_match_expected_results(n=1024):
torch.manual_seed(0x172893)
a = torch.arange(n).float()
@@ -226,7 +219,8 @@ def test_stochastic_math_helpers_match_expected_results(n=1024):
stochastic_add_divide_(c, b, alpha=1.0, divisor=2.0)
assert torch.allclose(c.float(), (a + b * 1) / 2)
- orig = heavyball.utils.default_division_backend
+ orig_backend = heavyball.utils.default_division_backend
+ orig_scale = heavyball.utils.atan2_scale
try:
heavyball.utils.atan2_scale = 1024
for backend in heavyball.utils.DivisionBackend:
@@ -235,7 +229,8 @@ def test_stochastic_math_helpers_match_expected_results(n=1024):
stochastic_divide_with_eps_(c, b)
assert torch.allclose(c.float(), a / b), f"Backend {backend} failed"
finally:
- heavyball.utils.default_division_backend = orig
+ heavyball.utils.default_division_backend = orig_backend
+ heavyball.utils.atan2_scale = orig_scale
def test_stochastic_math_accuracy():
diff --git a/test/test_utils_property.py b/test/test_utils_property.py
index d0c25cb..ec6c3e5 100644
--- a/test/test_utils_property.py
+++ b/test/test_utils_property.py
@@ -22,6 +22,9 @@
# Ensure torch.compile stays disabled on CPU-only CI runners.
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
+from torch._dynamo import config
+
+config.cache_size_limit = 128
heavyball.utils.compile_mode = None
diff --git a/test/utils.py b/test/utils.py
index 800c2fc..b00ff8c 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -13,17 +13,10 @@
# AdEMAMix variants require 3 betas, SplitOpt requires dict param specs,
# SAMWrapper and Newton variants require closures
_SKIP_GET_OPTIM = {
- "ForeachAdEMAMix",
- "ForeachSOAPAdEMAMix",
+ "AdEMAMix",
"SOAPAdEMAMix",
"SplitOpt",
"SAMWrapper",
- "ForeachCachedNewtonPSGD",
- "NewtonHybrid2PSGDKron",
- "ForeachNewtonPSGDLRA",
- "NewtonHybrid2PSGDLRA",
- "NewtonPSGDLRA",
- "NewtonPSGDKron",
}
@@ -40,8 +33,8 @@ def _fn_key(f):
def _deduplicate_by_chain(names):
"""Keep one optimizer per unique chain of functions.
- Two optimizers that differ only by foreach=True/False have identical
- chains and test the same code paths — keep whichever appears first.
+ Two optimizers that differ only by multi_tensor=True/False have identical
+ chains and test the same code paths, keep whichever appears first.
"""
seen = set()
out = []