From 318cf5f11287d18d5391673046950696777f1936 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sat, 28 Mar 2026 15:08:22 +0000
Subject: [PATCH 01/24] add PSGD PRO, fix existing PSGD
---
README.md | 2 +-
heavyball/__init__.py | 76 +++++++++++++++++++++++++
heavyball/chainable.py | 116 +++++++++++++++++++++++++++++++++++++++
heavyball/utils.py | 122 +++++++++++++++++++++++++++++++++--------
4 files changed, 291 insertions(+), 25 deletions(-)
diff --git a/README.md b/README.md
index e45eefe..d9ba71f 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,7 @@ Muon, MuonLaProp, OrthoLaProp, LaPropOrtho
SOAP, PaLMSOAP, PrecondScheduleSOAP, PrecondSchedulePaLMSOAP, SOAPNAdam, SOAPAdEMAMix, ForeachSOLP
**PSGD (Kronecker):**
-PSGDKron, CachedPSGDKron, DelayedPSGD, CachedDelayedPSGDKron, PurePSGD, NewtonPSGDKron, NewtonHybrid2PSGDKron
+PSGDPRO, PSGDKron, CachedPSGDKron, DelayedPSGD, CachedDelayedPSGDKron, PurePSGD, NewtonPSGDKron, NewtonHybrid2PSGDKron
`Newton`-PSGD requires a closure passed to `step()`.
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index e8c3eaf..784cfc1 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -1023,6 +1023,82 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
hvp_interval = 2
+class PSGDPRO(C.BaseOpt):
+ """
+ PSGD with Q0.5EQ1.5 (PRO/Procrustes) preconditioner update.
+ Solve-free alternative to standard PSGD-Kron (EQ method).
+ Reference: https://github.com/lixilinx/psgd_torch
+ """
+
+ cached: bool = False
+ exp_avg_input: bool = True
+
+ def __init__(
+ self,
+ params,
+ lr=0.001,
+ beta=None,
+ betas=(0.9, 0.999),
+ weight_decay=0.0,
+ preconditioner_update_probability=C.use_default,
+ max_size_triangular=2048,
+ min_ndim_triangular=2,
+ memory_save_mode=None,
+ momentum_into_precond_update=True,
+ warmup_steps: int = 0,
+ merge_dims: bool = False,
+ split: bool = False,
+ foreach: bool = True,
+ q_dtype="float32",
+ stochastic_schedule: bool = False,
+ storage_dtype: str = "float32",
+ mars: bool = False,
+ caution: bool = False,
+ mars_gamma: float = 0.0025,
+ cached: Optional[bool] = C.use_default,
+ exp_avg_input: Optional[bool] = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default,
+ precond_grad_accum: bool = False,
+ lower_bound_beta: float = 0.9,
+ dampening: float = 2**-13,
+ precond_update_power_iterations: int = 2,
+ precond_init_scale=None,
+ precond_init_scale_scale: float = 1,
+ precond_init_scale_power: Optional[float] = None,
+ precond_lr: float = 0.1,
+ compile_step: bool = C.use_default,
+ promote: bool = C.use_default,
+ ecc: str | None = None,
+ param_ecc: str | None = None,
+ **kwargs,
+ ):
+ cached = C.default(cached, self.cached)
+ exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
+ update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+
+ params, defaults = C._build_defaults(locals())
+ defaults["store_triu_as_line"] = False
+ defaults["inverse_free"] = False
+
+ self.precond_schedule = C.default(
+ defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
+ )
+
+ super().__init__(
+ params,
+ defaults,
+ foreach,
+ gradient_clipping,
+ update_clipping,
+ False,
+ fns=(
+ *(C.exp_avg,) * exp_avg_input,
+ functools.partial(C.scale_by_psgd_pro, cached=cached),
+ ),
+ )
+
+
class ForeachPSGDLRA(C.BaseOpt):
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index eca7ea7..0df6279 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -792,6 +792,27 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
state["Q_cache"] = [torch.empty_like(q) for q in Q]
+def _init_psgd_pro_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ Q = utils.init_Q_exprs(
+ grad,
+ group["precond_init_scale"],
+ group["precond_init_scale_scale"],
+ group["precond_init_scale_power"],
+ group["max_size_triangular"],
+ group["min_ndim_triangular"],
+ group["memory_save_mode"],
+ None,
+ None,
+ dtype=getattr(torch, group["q_dtype"]),
+ )
+ state["Q"] = Q
+ state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
+ state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
+ if not cached:
+ return
+ state["Q_cache"] = [torch.empty_like(q) for q in Q]
+
+
def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
state["U"], state["V"], state["d"] = utils.init_lra(
grad,
@@ -1094,6 +1115,56 @@ def _update_psgd_precond(
return None
+def _update_psgd_pro_precond(
+ cached,
+ Q_cache,
+ group,
+ param,
+ grad,
+ Q,
+ running_lower_bound,
+ step,
+ prob: Optional[callable] = None,
+) -> None:
+ if prob is None:
+ prob = utils.precond_update_prob_schedule()
+
+ if not group["is_preconditioning"]:
+ return
+
+ utils.psgd_pro_update_precond(
+ grad,
+ group["precond_lr"],
+ Q,
+ running_lower_bound,
+ group["lower_bound_beta"],
+ group["precond_update_power_iterations"],
+ group["dampening"],
+ )
+
+ if isinstance(prob, float):
+ float_prob = prob
+ else:
+ float_prob = prob(group["step"])
+ group["is_cached"] = should_use_cache = cached and float_prob < 0.5
+
+ if not should_use_cache or not cached:
+ return
+
+ for i, (c_, q_) in enumerate(zip(Q_cache, Q)):
+ if c_ is None:
+ c_ = (
+ torch.empty_like(q_)
+ if q_.ndim == 1
+ else torch.empty(q_.shape[0], q_.shape[0], device=q_.device, dtype=q_.dtype)
+ )
+ Q_cache[i] = c_
+ if q_.ndim == 2:
+ torch.matmul(q_.T, q_, out=c_)
+ else:
+ torch.mul(q_, q_, out=c_)
+
+
def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
kwargs = {"ea": update, "caution": group["caution"], "grad": grad}
if group.get("is_cached", False) and Q_cache[0] is not None:
@@ -1297,6 +1368,51 @@ def update_by_delayed_psgd(
raise SkipUpdate from None
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
+@no_state_no_foreach
+def scale_by_psgd_pro(
+ group,
+ update,
+ grad,
+ param,
+ update_to_precond,
+ Q,
+ Q_cache,
+ running_lower_bound: List[Tensor],
+ step: Tensor,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+):
+ _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+ return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+
+
+@needs_full_param
+@SqueezeGrad
+@PrecondGradAccumGuard
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
+@no_state_no_foreach
+def update_by_psgd_pro(
+ group,
+ update,
+ grad,
+ param,
+ update_to_precond,
+ Q,
+ Q_cache,
+ running_lower_bound: List[Tensor],
+ step: Tensor,
+ cached: bool = False,
+ prob: Optional[callable] = None,
+):
+ _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
+ _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
+ raise SkipUpdate from None
+
+
def palm_beta2(state, group, update, grad, param):
beta2 = 1 - group["step"] ** -group["beta2_scale"]
group["betas"] = (utils.get_beta1(group), beta2)
diff --git a/heavyball/utils.py b/heavyball/utils.py
index f935fca..984cf17 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -2525,9 +2525,10 @@ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
@decorator_knowngood
def dampen_grad(g: Tensor, damp: float = 2**-13):
- # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
+ # https://github.com/lixilinx/psgd_torch/blob/89b4cead31b7ad1494c4cf4dc39f4cbf920ff14d/psgd.py
v = torch.randn_like(g)
- return v, g + damp * g.abs().mean() * v
+ damping = damp + torch.finfo(g.dtype).eps * g.abs()
+ return v, g + damping * v
@decorator_knowngood
@@ -2768,6 +2769,44 @@ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False,
return max_singular_value_power_iter(A, None, iterations=power_iter)
+@decorator_knowngood
+def max_eigenvalue_spd(A_outer: Tensor, power_iter: int = 4) -> Tensor:
+ """Power iteration for the largest eigenvalue of a symmetric positive (semi)definite matrix.
+ Exploits A = A^T: A^T A = A^2, so v -> A^T(Av) = v -> A(Av), saving a transpose.
+ Uses x @ A.mT (gemm transB=true) for faster BLAS dispatch than A.mv(x)."""
+ if A_outer.ndim < 2:
+ return A_outer.max()
+ x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
+ x_norm = promote(x_norm)
+
+ def _inner():
+ x = A_outer.index_select(0, max_idx).flatten().contiguous()
+ A = stochastic_round_(A_outer / x_norm)
+ x = x / x_norm
+
+ def _mv(x):
+ return promote((x.to(A.dtype) @ A.mT) @ A.mT)
+
+ for _ in range(power_iter):
+ x = F.normalize(_mv(x), dim=0)
+ return (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
+
+ return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone()).squeeze()
+
+
+@decorator_knowngood
+def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None:
+ R = (Q.T - Q).contiguous()
+ R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal
+ R = R / R_norm
+ RQ = R @ Q
+ RRQ = R @ RQ
+ tr_RQ = RQ.diagonal().sum()
+ tr_RRQ = RRQ.diagonal().sum()
+ a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size)
+ Q.add_(a * (RQ + 0.5 * a * RRQ))
+
+
@decorator_knowngood
def clamped_max_singular_value(
A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
@@ -2927,22 +2966,11 @@ def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
return coeff0.float(), coeffs.float()
-@decorator_knowngood
-def _psgd_default_preconditioner_grad(
- terms: List[Tuple[Tensor, Tensor]],
- Q: List[Tensor],
-) -> List[Tensor]:
- out = []
- for q, (x, y) in zip(Q, terms):
- x = promote(x)
- y = promote(y)
- update = x - y
- if q.ndim < 2:
- update = promote(q) * update
- else:
- update = (promote(q) @ update).triu()
- out.append(update)
- return out
+def _update_lb(ell: Tensor, lb_state: Tensor, beta: Tensor) -> Tensor:
+ ell = promote(ell)
+ ell = ell.maximum(promote(lb_state) + (ell - promote(lb_state)) * (1 - beta))
+ copy_stochastic_(lb_state, ell)
+ return ell
@decorator
@@ -2965,15 +2993,61 @@ def psgd_update_precond(
precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
A, conjB = psgd_calc_A_and_conjB(G, Q, V)
- terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
- del A, conjB, V
- updates = _psgd_default_preconditioner_grad(terms, Q)
- _psgd_precond_update_(
- updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
- )
+ del V
+
+ for oq_i, q, exprG, lb_state in zip(oq, Q, exprGs, running_lower_bound):
+ term1 = promote(compiled_einsum(exprG, A, A))
+ term2 = promote(compiled_einsum(exprG, conjB, conjB))
+
+ if q.ndim < 2:
+ ell = _update_lb((term1 + term2).max(), lb_state, lower_bount_beta)
+ update = promote(q) * (term1 - term2)
+ else:
+ ell = _update_lb(max_eigenvalue_spd(term1 + term2, power_iter=power_iter), lb_state, lower_bount_beta)
+ update = (term1 - term2).triu() @ promote(q)
+ if store_triu_as_line:
+ update = triu_to_line([update])[0][1]
+
+ real_oq = oq_i[1] if isinstance(oq_i, tuple) else oq_i
+ copy_stochastic_(real_oq, promote(real_oq) - update / ell * precond_lr)
return None
+@decorator
+def psgd_pro_update_precond(
+ G: Tensor,
+ precond_lr: float,
+ Q: List[Tensor],
+ running_lower_bound: List[Tensor],
+ lower_bount_beta: float,
+ power_iter: int,
+ dampening: float,
+) -> None:
+ """Update Kronecker product preconditioner Q with Q0.5EQ1.5 (PRO) method."""
+ psgd_balance_Q(Q)
+ exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
+ precond_lr, lower_bount_beta = scalar_guard(precond_lr, lower_bount_beta, G)
+
+ damping = dampening + torch.finfo(G.dtype).eps * G.abs()
+ Pg = psgd_precond_grad(G + damping * torch.randn_like(G), Q)
+
+ total_numel = G.numel()
+ for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound):
+ term1 = promote(compiled_einsum(exprG, Pg, Pg))
+ q_ = promote(q)
+
+ if q.ndim < 2:
+ term2 = total_numel / max(1, q.numel())
+ ell = _update_lb(term1.max() + term2, lb_state, lower_bount_beta)
+ copy_stochastic_(q, q_ - q_ * (term1 - term2) / ell * precond_lr)
+ else:
+ term2 = total_numel / q.shape[0]
+ ell = _update_lb(max_eigenvalue_spd(term1, power_iter=power_iter) + term2, lb_state, lower_bount_beta)
+ copy_stochastic_(q, q_ - (term1 @ q_ - term2 * q_) / ell * precond_lr)
+ procrustes_step(q)
+ del Pg
+
+
@decorator_knowngood
def bf16_matmul(x: Tensor, y: Tensor):
return (promote(x) @ promote(y)).to(x.dtype)
From 450d1b2276f07dbc6a86b17cf86c23de08f0cfb8 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sat, 28 Mar 2026 17:04:34 +0000
Subject: [PATCH 02/24] add HyperBallAdamW, MuonAdamW
---
heavyball/__init__.py | 84 +++++++++++++++++++++++++++++-
heavyball/chainable.py | 113 ++++++++++++++++++++++++++++++++++++++---
heavyball/utils.py | 46 ++++++++++++++---
3 files changed, 227 insertions(+), 16 deletions(-)
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index 784cfc1..379acd6 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -303,6 +303,86 @@ def __init__(
)
+class HyperBallAdamW(C.BaseOpt):
+ def __init__(
+ self,
+ params,
+ lr=0.0025,
+ betas=(0.9, 0.99),
+ eps=1e-8,
+ weight_decay=0,
+ warmup_steps=0,
+ foreach: bool = True,
+ storage_dtype: str = "float32",
+ mars: bool = False,
+ caution: bool = False,
+ mars_gamma: float = 0.0025,
+ gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default,
+ palm: bool = C.use_default,
+ beta2_scale: float = 0.8,
+ compile_step: bool = C.use_default,
+ promote: bool = C.use_default,
+ ecc: str | None = None,
+ param_ecc: str | None = None,
+ **kwargs,
+ ):
+ params, defaults = C._build_defaults(locals())
+ super().__init__(
+ params,
+ defaults,
+ foreach,
+ gradient_clipping,
+ update_clipping,
+ palm,
+ fns=(C.scale_by_exp_avg_sq, C.route(
+ (lambda p: p.ndim >= 2, C.update_by_hyperball),
+ default=C.apply_update,
+ )),
+ )
+
+
+class MuonAdamW(C.BaseOpt):
+ def __init__(
+ self,
+ params,
+ lr=0.0025,
+ betas=(0.9, 0.99),
+ eps=1e-8,
+ weight_decay=0,
+ warmup_steps=0,
+ foreach: bool = True,
+ storage_dtype: str = "float32",
+ mars: bool = False,
+ caution: bool = False,
+ mars_gamma: float = 0.0025,
+ gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default,
+ palm: bool = C.use_default,
+ beta2_scale: float = 0.8,
+ nesterov: bool = True,
+ compile_step: bool = C.use_default,
+ promote: bool = C.use_default,
+ ecc: str | None = None,
+ param_ecc: str | None = None,
+ **kwargs,
+ ):
+ params, defaults = C._build_defaults(locals())
+ ema = C.nesterov_ema if nesterov else C.exp_avg
+ super().__init__(
+ params,
+ defaults,
+ foreach,
+ gradient_clipping,
+ update_clipping,
+ palm,
+ fns=(C.route(
+ (lambda p: p.ndim >= 2, (ema, C.orthogonalize_update)),
+ default=C.scale_by_adam,
+ ),),
+ )
+
+
class ForeachSFAdamW(C.ScheduleFree):
def __init__(
self,
@@ -951,7 +1031,7 @@ def __init__(
precond_grad_accum: bool = False,
lower_bound_beta: float = 0.9, # 0.0 recovers pre-2.0.0 PSGD
inverse_free: bool = C.use_default,
- dampening: float = 2**-13,
+ dampening: float = 1e-9,
precond_update_power_iterations: int = 2,
# expert parameters
precond_init_scale=None,
@@ -1061,7 +1141,7 @@ def __init__(
update_clipping: C.str_or_fn = C.use_default,
precond_grad_accum: bool = False,
lower_bound_beta: float = 0.9,
- dampening: float = 2**-13,
+ dampening: float = 1e-9,
precond_update_power_iterations: int = 2,
precond_init_scale=None,
precond_init_scale_scale: float = 1,
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index 0df6279..dc16ce0 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -85,20 +85,94 @@ def __repr__(self):
return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
-class Branch:
+def _enforce_uniform_skip(results):
+ skips = [skip for _, skip, _ in results]
+ if not skips:
+ return False
+ if all(skips):
+ return True
+ if not any(skips):
+ return False
+ raise ValueError("All branches must uniformly skip or not skip updates")
+
+
+def _normalize_chain(fns):
+ if fns is None:
+ return None
+ return fns if isinstance(fns, (list, tuple)) else (fns,)
+
+
+class Parallel:
def __init__(self, branches: List[List[callable]], merge_fn: callable):
self.branches = branches
self.merge_fn = merge_fn
def __call__(self, state, group, update, grad, param):
- outputs = []
+ results = []
for branch in self.branches:
branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
- branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
- if skip_update:
- raise ValueError("Branches should not skip updates")
- outputs.append(branch_update)
- return self.merge_fn(outputs)
+ u, skip = _inner_chain(state, group, branch_update, grad, param, *branch)
+ results.append((u, skip, None))
+ if _enforce_uniform_skip(results):
+ raise SkipUpdate from None
+ return self.merge_fn([u for u, _, _ in results])
+
+
+class Route:
+ """Route params by predicate through different fn chains.
+
+ Takes arbitrary (predicate, fns) pairs. Each param is assigned to the first
+ matching route; unmatched params use the default chain (None = passthrough).
+ All routes must uniformly either skip or not skip updates.
+ """
+
+ def __init__(self, *routes, default=None):
+ self.routes = [(pred, _normalize_chain(fns)) for pred, fns in routes]
+ self.default = _normalize_chain(default)
+
+ def __call__(self, state, group, update, grad, param):
+ buckets = {}
+ assigned = set()
+ for j, (pred, _) in enumerate(self.routes):
+ for i, p in enumerate(param):
+ if i not in assigned and pred(p):
+ buckets.setdefault(j, []).append(i)
+ assigned.add(i)
+ default_idx = [i for i in range(len(param)) if i not in assigned]
+
+ def _sel(lst, idx):
+ return [lst[i] for i in idx]
+
+ caution = group["caution"]
+ results = []
+
+ all_chains = [(buckets.get(j), fns) for j, (_, fns) in enumerate(self.routes)]
+ if default_idx:
+ all_chains.append((default_idx, self.default))
+
+ for idx, fns in all_chains:
+ if not idx:
+ continue
+ group["caution"] = caution
+ if fns is not None:
+ u, skip = _inner_chain(state, group, _sel(update, idx), _sel(grad, idx), _sel(param, idx), *fns)
+ else:
+ u, skip = _sel(update, idx), False
+ results.append((u, skip, idx))
+
+ if _enforce_uniform_skip(results):
+ raise SkipUpdate from None
+
+ out = [None] * len(param)
+ for u_list, _, idx in results:
+ if u_list is not None:
+ for i, u in zip(idx, u_list):
+ out[i] = u
+ return out
+
+
+def route(*routes, default=None):
+ return Route(*routes, default=default)
def _zero_guard(state, key, ref, dtype):
@@ -405,6 +479,12 @@ def identity(state, group, update, grad, param):
return update
+@no_state
+def apply_update(group, update, grad, param):
+ utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
+ raise SkipUpdate from None
+
+
@zero_guard("exp_avg")
@no_state
def weight_decay_to_ema(group, update, grad, param, exp_avg):
@@ -867,6 +947,18 @@ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grok
return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+def _store_init_norm(state, group, update, grad, param):
+ state["init_norm"] = param.float().norm()
+
+
+@needs_full_param
+@general_guard("init_norm", init_fn=_store_init_norm, skip_first=False)
+@no_state
+def update_by_hyperball(group, update, grad, param, init_norm):
+ utils.hyperball_step_(param, update, init_norm, group["lr"], group["weight_decay"], group["caution"], grad)
+ raise SkipUpdate from None
+
+
def _store_std(state, group, update, grad, param):
state["init_std"] = torch.std(param)
@@ -1616,9 +1708,14 @@ def _walk_fns(obj):
stack.append(cur.fn)
elif isinstance(cur, functools.partial):
stack.append(cur.func)
- elif isinstance(cur, Branch):
+ elif isinstance(cur, Parallel):
for branch in cur.branches:
stack.extend(branch)
+ elif isinstance(cur, Route):
+ for _, fns in cur.routes:
+ stack.extend(fns)
+ if cur.default is not None:
+ stack.extend(cur.default)
elif isinstance(cur, _Iterable) and not isinstance(cur, (str, bytes, bytearray)):
stack.extend(cur)
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 984cf17..7cb8675 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1504,7 +1504,7 @@ def _finite_differences_hvp(self, closure):
p.vector = torch.randn_like(p)
p.orig = p.data.clone()
# scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
- stochastic_add_(p.data, p.vector, torch.finfo(p.dtype).eps ** 0.5)
+ stochastic_add_(p.data, p.vector, torch.finfo(torch.float32).eps ** 0.5)
with torch.enable_grad():
closure()
@@ -1514,7 +1514,7 @@ def _finite_differences_hvp(self, closure):
for group in self.param_groups:
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
p.grad = grads.pop(0)
- stochastic_add_divide_(g, p.grad, -1, torch.finfo(p.dtype).eps ** 0.5)
+ stochastic_add_divide_(g, p.grad, -1, torch.finfo(torch.float32).eps ** 0.5)
p.hessian_vector = g
p.data.copy_(p.orig)
del p.orig
@@ -2524,10 +2524,9 @@ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
@decorator_knowngood
-def dampen_grad(g: Tensor, damp: float = 2**-13):
- # https://github.com/lixilinx/psgd_torch/blob/89b4cead31b7ad1494c4cf4dc39f4cbf920ff14d/psgd.py
+def dampen_grad(g: Tensor, damp: float = 1e-9):
v = torch.randn_like(g)
- damping = damp + torch.finfo(g.dtype).eps * g.abs()
+ damping = damp + torch.finfo(torch.float32).eps * g.abs()
return v, g + damping * v
@@ -3028,7 +3027,7 @@ def psgd_pro_update_precond(
exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
precond_lr, lower_bount_beta = scalar_guard(precond_lr, lower_bount_beta, G)
- damping = dampening + torch.finfo(G.dtype).eps * G.abs()
+ damping = dampening + torch.finfo(torch.float32).eps * G.abs()
Pg = psgd_precond_grad(G + damping * torch.randn_like(G), Q)
total_numel = G.numel()
@@ -3693,6 +3692,41 @@ def caution(g, update):
return _compilable_cautioning(g, update)
+@decorator_knowngood
+def _compilable_hyperball_(
+ p: List[Tensor],
+ u: List[Tensor],
+ init_norm: List[Tensor],
+ lr: Tensor,
+ decay: float,
+ caution: bool,
+ g: List[Tensor],
+):
+ for op, u_, n_, g_ in zip(p, u, init_norm, g):
+ u_ = promote(u_.view_as(op))
+ p_ = promote(op)
+ if decay != 0:
+ u_ = u_ + p_ * decay
+ if caution:
+ u_ = _compilable_cautioning(promote(g_), u_)
+ u_norm = u_.norm()
+ u_norm = u_norm.clamp(min=1e-12)
+ u_ = u_ / u_norm
+ p_ = p_ - lr * u_ * n_
+ p_norm = p_.norm()
+ p_norm = p_norm.clamp(min=1e-12)
+ p_ = p_ * (n_ / p_norm)
+ copy_stochastic_(op, p_)
+
+
+def hyperball_step_(param, update, init_norm, lr, decay, caution, grad):
+ param, update, init_norm, grad = list_guard(param, update, init_norm, grad)
+ lr = scalar_guard(lr, param[0])
+ if not caution:
+ grad = [None] * len(param)
+ _compilable_hyperball_(param, update, init_norm, lr, decay, caution, grad)
+
+
def _inner_precond_update_prob_schedule(
n: int, max_prob: float = 1.0, min_prob: float = 0.03, decay: float = 0.999, flat_start: float = 1000
):
From bff2265ee7cef70d6eebb4dd376233afbd9cd177 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sat, 28 Mar 2026 18:13:56 +0000
Subject: [PATCH 03/24] rename everything
---
README.md | 20 +--
docs/benchmark.md | 4 +-
docs/heavyball2.md | 4 +-
examples/branched_optimizer.py | 4 +-
examples/ddp_training.py | 6 +-
examples/ecc_bf16.py | 6 +-
examples/fsdp_training.py | 8 +-
heavyball/__init__.py | 235 ++++++++++-------------------
heavyball/chainable.py | 40 ++---
heavyball/utils.py | 29 ++--
scripts/migrate_optimizer_state.py | 2 +-
test/test_ademamix.py | 6 +-
test/test_chainable_cpu.py | 50 +++---
test/test_cpu_features.py | 10 +-
test/test_distributed.py | 19 +--
test/test_ecc.py | 82 +++++-----
test/test_foreach.py | 18 +--
test/test_memory_leak.py | 2 +-
test/test_merge.py | 2 +-
test/test_migrate_cli.py | 6 +-
test/test_optimizer_cpu_smoke.py | 2 +-
test/test_param_ecc_compile.py | 2 +-
test/test_toy_training.py | 14 +-
test/utils.py | 12 +-
24 files changed, 241 insertions(+), 342 deletions(-)
diff --git a/README.md b/README.md
index d9ba71f..56b5359 100644
--- a/README.md
+++ b/README.md
@@ -55,29 +55,25 @@ training, and SAM.
Full list
**First-order:**
-AdamW, NAdam, RMSprop, ADOPT, ForeachAdEMAMix, LaProp, SignLaProp, SGD, Scion, UnscaledAdamW, ForeachAdamC, SUDSAdamW
+AdamW, NAdam, RMSprop, ADOPT, AdEMAMix, LaProp, SignLaProp, SGD, Scion, UnscaledAdamW, AdamC, SUDSAdamW
**Schedule-Free:**
-SFAdamW, PaLMSFAdamW
+SFAdamW
Schedule-Free optimizers override `.eval()` and `.train()` to swap between training and evaluation parameter states.
Call `opt.eval()` before validation and `opt.train()` before resuming training.
**Orthogonal:**
-Muon, MuonLaProp, OrthoLaProp, LaPropOrtho
+Muon, MuonAdamW, MuonLaProp, HyperBallAdamW, OrthoLaProp, LaPropOrtho
**Shampoo-based (SOAP):**
-SOAP, PaLMSOAP, PrecondScheduleSOAP, PrecondSchedulePaLMSOAP, SOAPNAdam, SOAPAdEMAMix, ForeachSOLP
+SOAP, SOAPNAdam, SOAPAdEMAMix, SOLP
**PSGD (Kronecker):**
-PSGDPRO, PSGDKron, CachedPSGDKron, DelayedPSGD, CachedDelayedPSGDKron, PurePSGD, NewtonPSGDKron, NewtonHybrid2PSGDKron
-
-`Newton`-PSGD requires a closure passed to `step()`.
+PSGDKron, PSGDPRO
**PSGD (Low-Rank):**
-PSGDLRA, DelayedPSGDLRA, NewtonPSGDLRA, NewtonHybrid2PSGDLRA
-
-`Newton`-PSGD requires a closure passed to `step()`.
+PSGDLRA
**SAM:**
SAMWrapper, MSAMLaProp
@@ -169,11 +165,11 @@ def graft(outputs, eps=1e-8):
class GraftedAdam(C.BaseOpt):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, warmup_steps=0, foreach=True):
+ weight_decay=0, warmup_steps=0, multi_tensor=True):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
warmup_steps=warmup_steps)
branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft)
- super().__init__(params, defaults, foreach, fns=(branch,))
+ super().__init__(params, defaults, multi_tensor, fns=(branch,))
```
Custom optimizers that inherit from `BaseOpt` get ECC, MARS, caution, clipping, warmup, and stochastic rounding
diff --git a/docs/benchmark.md b/docs/benchmark.md
index 398b051..5fdb320 100644
--- a/docs/benchmark.md
+++ b/docs/benchmark.md
@@ -56,9 +56,9 @@ reinforcing the need for diagnostic rather than purely comparative evaluation.
| Optimizer | Cautious¹ | Mars² | Success | Attempts | Avg Runtime (s) |
|:---------------|:----------|:------|:--------|:---------|:----------------|
| PSGDKron | No | No | 77.0% | 73.2 | 8240 |
-| NewtonPSGDKron | No | No | 77.0% | 80.5 | 9052 |
+| PSGDKron (Newton) | No | No | 77.0% | 80.5 | 9052 |
| AdamW | Yes | No | 75.7% | 61.2 | 8072 |
-| ForeachSOAP | No | No | 72.5% | 77.9 | 7827 |
+| SOAP | No | No | 72.5% | 77.9 | 7827 |
| AdamW | No | No | 72.3% | 107.8 | 10029 |
| MuonLaProp | No | No | 68.2% | 82.7 | 10141 |
| RMSprop | No | No | 55.6% | 114.4 | 10725 |
diff --git a/docs/heavyball2.md b/docs/heavyball2.md
index 5660cb1..7bc8555 100644
--- a/docs/heavyball2.md
+++ b/docs/heavyball2.md
@@ -4,7 +4,7 @@
* First‑class SAM via `SAMWrapper` (closure‑based)
* More robust checkpoint/restore with HeavyBall‑internal state
-* New optimizers: `SGD`, `ForeachAdamC`, `MSAMLaProp`
+* New optimizers: `SGD`, `AdamC`, `MSAMLaProp`
* Overhauled chainable pipeline: indexed transforms, branching, internal gradient‑accumulation, and `SqueezeGrad`
* Faster, more accurate code paths
* New `heavyball.helpers` with Optuna‑compatible samplers and utilities
@@ -18,7 +18,7 @@
* `SAMWrapper` applies sharpness‑aware minimization to any HeavyBall optimizer while preserving the wrapped step logic;
requires a closure
* `SGD` built on the chainable internals
-* `ForeachAdamC`, a ["corrected version of Adam"](https://arxiv.org/abs/2506.02285) with weight decay normalized by the
+* `AdamC`, a ["corrected version of Adam"](https://arxiv.org/abs/2506.02285) with weight decay normalized by the
maximum LR
* `MSAMLaProp` built on top of [Momentum‑SAM](https://arxiv.org/abs/2401.12033)
* Chainable pipeline:
diff --git a/examples/branched_optimizer.py b/examples/branched_optimizer.py
index b8a2795..001f107 100644
--- a/examples/branched_optimizer.py
+++ b/examples/branched_optimizer.py
@@ -29,7 +29,7 @@ def __init__(
eps: float = 1e-8,
weight_decay: float = 1e-4,
warmup_steps: int = 0,
- foreach: bool = True,
+ multi_tensor: bool = True,
):
defaults = dict(
lr=lr,
@@ -39,7 +39,7 @@ def __init__(
warmup_steps=warmup_steps,
)
branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=_graft)
- super().__init__(params, defaults, foreach, fns=(branch,))
+ super().__init__(params, defaults, multi_tensor, fns=(branch,))
def main(epochs: int = 20, batch_size: int = 256, subset_size: int = 4096):
diff --git a/examples/ddp_training.py b/examples/ddp_training.py
index dc27092..bc39932 100644
--- a/examples/ddp_training.py
+++ b/examples/ddp_training.py
@@ -2,8 +2,8 @@
Launch with torchrun:
torchrun --nproc_per_node=2 examples/ddp_training.py
- torchrun --nproc_per_node=2 examples/ddp_training.py --opt ForeachSOAP
- torchrun --nproc_per_node=2 examples/ddp_training.py --opt ForeachMuon --lr 0.01
+ torchrun --nproc_per_node=2 examples/ddp_training.py --opt SOAP
+ torchrun --nproc_per_node=2 examples/ddp_training.py --opt Muon --lr 0.01
All HeavyBall optimizers work transparently with DDP
"""
@@ -38,7 +38,7 @@ def make_model():
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--opt", default="ForeachAdamW")
+ parser.add_argument("--opt", default="AdamW")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--lr", type=float, default=1e-3)
diff --git a/examples/ecc_bf16.py b/examples/ecc_bf16.py
index 1b8d1bf..7effd5d 100644
--- a/examples/ecc_bf16.py
+++ b/examples/ecc_bf16.py
@@ -19,11 +19,11 @@
CONFIGS = {
"naive_fp32": lambda p: NaiveAdamW(p, lr=LR, betas=BETAS, eps=EPS, state_dtype=torch.float32),
"naive_bf16": lambda p: NaiveAdamW(p, lr=LR, betas=BETAS, eps=EPS, state_dtype=torch.bfloat16),
- "heavyball_fp32": lambda p: heavyball.ForeachAdamW(p, lr=LR, betas=BETAS, eps=EPS, storage_dtype="float32"),
- "heavyball_bf16": lambda p: heavyball.ForeachAdamW(
+ "heavyball_fp32": lambda p: heavyball.AdamW(p, lr=LR, betas=BETAS, eps=EPS, storage_dtype="float32"),
+ "heavyball_bf16": lambda p: heavyball.AdamW(
p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, storage_dtype="bfloat16"
),
- "ecc_bf16+8": lambda p: heavyball.ForeachAdamW(p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, ecc="bf16+8"),
+ "ecc_bf16+8": lambda p: heavyball.AdamW(p, lr=LR, betas=BETAS, eps=EPS, weight_decay=0, ecc="bf16+8"),
}
COLORS = {
diff --git a/examples/fsdp_training.py b/examples/fsdp_training.py
index 375d0e8..06087de 100644
--- a/examples/fsdp_training.py
+++ b/examples/fsdp_training.py
@@ -2,8 +2,8 @@
Launch with torchrun:
torchrun --nproc_per_node=2 examples/fsdp_training.py
- torchrun --nproc_per_node=2 examples/fsdp_training.py --opt ForeachSOAP
- torchrun --nproc_per_node=2 examples/fsdp_training.py --opt ForeachMuon --lr 0.01
+ torchrun --nproc_per_node=2 examples/fsdp_training.py --opt SOAP
+ torchrun --nproc_per_node=2 examples/fsdp_training.py --opt Muon --lr 0.01
Shape-aware optimizers (SOAP, Muon, PSGD, Scion, etc.) auto-detect FSDP-flattened
params and restore original shapes. No manual intervention needed.
@@ -11,7 +11,7 @@
For non-FSDP parallelism backends, capture shapes before wrapping:
shapes = heavyball.capture_param_shapes(model)
model = your_wrapper(model)
- opt = heavyball.ForeachSOAP(model.parameters(), lr=3e-3, orig_shapes=shapes)
+ opt = heavyball.SOAP(model.parameters(), lr=3e-3, orig_shapes=shapes)
"""
import argparse
@@ -44,7 +44,7 @@ def make_model():
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--opt", default="ForeachAdamW")
+ parser.add_argument("--opt", default="AdamW")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--lr", type=float, default=1e-3)
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index 379acd6..476aac0 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -16,7 +16,7 @@ def __init__(
beta=0.9,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -30,10 +30,10 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
-class ForeachAdamW(C.BaseOpt):
+class AdamW(C.BaseOpt):
def __init__(
self,
params,
@@ -42,7 +42,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -58,10 +58,10 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
-class ForeachNAdam(C.BaseOpt):
+class NAdam(C.BaseOpt):
def __init__(
self,
params,
@@ -72,7 +72,7 @@ def __init__(
momentum_decay: float = 4e-3,
decoupled_weight_decay: bool = False,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -88,10 +88,10 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,))
-class ForeachAdEMAMix(C.BaseOpt):
+class AdEMAMix(C.BaseOpt):
def __init__(
self,
params,
@@ -103,7 +103,7 @@ def __init__(
beta3_warmup: Optional[int] = None,
alpha_warmup: Optional[int] = None,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -120,7 +120,7 @@ def __init__(
raise ValueError("AdEMAMix expects betas with three coefficients.")
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,))
class UnscaledAdamW(C.BaseOpt):
@@ -132,7 +132,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -149,7 +149,7 @@ def __init__(
):
params, defaults = C._build_defaults(locals())
super().__init__(
- params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+ params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
)
@@ -162,7 +162,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -179,7 +179,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
class Scion(C.BaseOpt):
@@ -191,7 +191,7 @@ def __init__(
eps: float = 1e-8,
weight_decay: float = 0,
warmup_steps: int = 0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -222,11 +222,11 @@ def __init__(
defaults.pop("momentum", None)
super().__init__(
- params, defaults, foreach, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm)
+ params, defaults, multi_tensor, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm)
)
-class ForeachAdamC(C.BaseOpt):
+class AdamC(C.BaseOpt):
def __init__(
self,
params,
@@ -236,7 +236,7 @@ def __init__(
weight_decay=0,
max_lr: float | None = None,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -258,10 +258,10 @@ def __init__(
max_lr = lr
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
-class ForeachRMSprop(C.BaseOpt):
+class RMSprop(C.BaseOpt):
"""
Debiased RMSprop (not torch.optim.RMSprop)
"""
@@ -276,7 +276,7 @@ def __init__(
warmup_steps=0,
r=0.0,
weight_lr_power=2.0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -295,7 +295,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -312,7 +312,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -331,7 +331,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -351,7 +351,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -372,7 +372,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -383,7 +383,7 @@ def __init__(
)
-class ForeachSFAdamW(C.ScheduleFree):
+class SFAdamW(C.ScheduleFree):
def __init__(
self,
params,
@@ -394,7 +394,7 @@ def __init__(
warmup_steps=0,
r=0.0,
weight_lr_power=2.0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -413,7 +413,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -432,7 +432,7 @@ def __init__(
warmup_steps=0,
r=0.0,
weight_lr_power=2.0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -452,7 +452,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -460,11 +460,7 @@ def __init__(
)
-class PaLMForeachSFAdamW(ForeachSFAdamW):
- palm: bool = True
-
-
-class ForeachADOPT(C.BaseOpt):
+class ADOPT(C.BaseOpt):
def __init__(
self,
params,
@@ -473,7 +469,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -489,10 +485,10 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
-class ForeachMuon(C.BaseOpt):
+class Muon(C.BaseOpt):
def __init__(
self,
params,
@@ -501,7 +497,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -530,7 +526,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -538,7 +534,7 @@ def __init__(
)
-class ForeachLaProp(C.BaseOpt):
+class LaProp(C.BaseOpt):
def __init__(
self,
params,
@@ -547,7 +543,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -563,7 +559,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
+ super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
class MuonLaProp(C.BaseOpt):
@@ -575,7 +571,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -594,7 +590,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -602,9 +598,9 @@ def __init__(
)
-class ForeachSOAP(C.BaseOpt):
+class SOAP(C.BaseOpt):
"""
- ForeachSOAP
+ SOAP
Sources:
Baseline SOAP:
@@ -632,7 +628,7 @@ def __init__(
correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -664,7 +660,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm, #
@@ -672,7 +668,7 @@ def __init__(
)
-class ForeachSOAPNAdam(C.BaseOpt):
+class SOAPNAdam(C.BaseOpt):
use_precond_schedule: bool = False
def __init__(
@@ -691,7 +687,7 @@ def __init__(
correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -725,7 +721,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -733,7 +729,7 @@ def __init__(
)
-class ForeachSOAPAdEMAMix(C.BaseOpt):
+class SOAPAdEMAMix(C.BaseOpt):
use_precond_schedule: bool = False
def __init__(
@@ -752,7 +748,7 @@ def __init__(
correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -787,7 +783,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -795,7 +791,7 @@ def __init__(
)
-class ForeachSignLaProp(C.BaseOpt):
+class SignLaProp(C.BaseOpt):
def __init__(
self,
params,
@@ -804,7 +800,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -823,7 +819,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -831,9 +827,9 @@ def __init__(
)
-class ForeachSOLP(C.BaseOpt):
+class SOLP(C.BaseOpt):
"""
- ForeachSOLP
+ SOLP
Sources:
Baseline SOAP:
@@ -861,7 +857,7 @@ def __init__(
correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
mars: bool = False,
caution: bool = False,
mars_gamma: float = 0.0025,
@@ -892,7 +888,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm, #
@@ -900,20 +896,6 @@ def __init__(
)
-class PaLMForeachSOAP(ForeachSOAP):
- use_precond_schedule: bool = False
- palm: bool = True
-
-
-class PrecondScheduleForeachSOAP(ForeachSOAP):
- use_precond_schedule: bool = True
-
-
-class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
- use_precond_schedule: bool = True
- palm: bool = True
-
-
class OrthoLaProp(C.BaseOpt):
def __init__(
self,
@@ -923,7 +905,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -942,7 +924,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -959,7 +941,7 @@ def __init__(
eps=1e-8,
weight_decay=0,
warmup_steps=0,
- foreach: bool = True,
+ multi_tensor: bool = True,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -978,7 +960,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -986,7 +968,7 @@ def __init__(
)
-class ForeachPSGDKron(C.BaseOpt):
+class PSGDKron(C.BaseOpt):
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
Modified under Creative Commons Attribution 4.0 International
@@ -1014,7 +996,7 @@ def __init__(
merge_dims: bool = False,
split: bool = False,
store_triu_as_line: bool = True,
- foreach: bool = True,
+ multi_tensor: bool = True,
q_dtype="float32",
stochastic_schedule: bool = False,
storage_dtype: str = "float32",
@@ -1067,7 +1049,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
False, #
@@ -1078,31 +1060,6 @@ def __init__(
)
-class ForeachPurePSGD(ForeachPSGDKron):
- exp_avg_input: bool = False
-
-
-class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
- delayed: bool = True
- cached: bool = True
-
-
-class ForeachCachedPSGDKron(ForeachPSGDKron):
- cached: bool = True
-
-
-class ForeachDelayedPSGD(ForeachPSGDKron):
- delayed: bool = True
-
-
-class ForeachCachedNewtonPSGD(ForeachCachedPSGDKron):
- hessian_approx = True
-
-
-class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
- hvp_interval = 2
-
-
class PSGDPRO(C.BaseOpt):
"""
PSGD with Q0.5EQ1.5 (PRO/Procrustes) preconditioner update.
@@ -1128,7 +1085,7 @@ def __init__(
warmup_steps: int = 0,
merge_dims: bool = False,
split: bool = False,
- foreach: bool = True,
+ multi_tensor: bool = True,
q_dtype="float32",
stochastic_schedule: bool = False,
storage_dtype: str = "float32",
@@ -1155,7 +1112,7 @@ def __init__(
):
cached = C.default(cached, self.cached)
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
- update_clipping = C.default(update_clipping, utils.trust_region_clip_)
+ update_clipping = C.default(update_clipping, None)
params, defaults = C._build_defaults(locals())
defaults["store_triu_as_line"] = False
@@ -1168,7 +1125,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
False,
@@ -1179,14 +1136,14 @@ def __init__(
)
-class ForeachPSGDLRA(C.BaseOpt):
+class PSGDLRA(C.BaseOpt):
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
Modified under Creative Commons Attribution 4.0 International
Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
- Note: `foreach=True` (default) uses a single global low-rank approximation shared across all
- parameters, while `foreach=False` fits an independent per-parameter LRA. These are different
+ Note: `multi_tensor=True` (default) uses a single global low-rank approximation shared across all
+ parameters, while `multi_tensor=False` fits an independent per-parameter LRA. These are different
algorithms and will produce different results.
"""
@@ -1203,7 +1160,7 @@ def __init__(
momentum_into_precond_update=True,
rank: Optional[int] = None,
warmup_steps: int = 0,
- foreach: bool = True, # True: global LRA across all params. False: independent per-param LRA.
+ multi_tensor: bool = True, # True: global LRA across all params. False: independent per-param LRA.
q_dtype="float32",
stochastic_schedule: bool = False,
storage_dtype: str = "float32",
@@ -1251,7 +1208,7 @@ def __init__(
super().__init__(
params,
defaults,
- foreach,
+ multi_tensor,
gradient_clipping,
update_clipping,
False, #
@@ -1259,18 +1216,6 @@ def __init__(
)
-class ForeachDelayedPSGDLRA(ForeachPSGDLRA):
- delayed: bool = True
-
-
-class ForeachNewtonPSGDLRA(ForeachPSGDLRA):
- hessian_approx = True
-
-
-class NewtonHybrid2PSGDLRA(ForeachNewtonPSGDLRA):
- hvp_interval = 2
-
-
class SplitOpt(utils.StatefulOptimizer):
"""
Delegates different parameter groups to different underlying optimizers.
@@ -1291,7 +1236,7 @@ def __init__(self, specs):
all_params.extend(params)
if not self.optimizers:
raise ValueError("No optimizers created")
- super().__init__(all_params, {}, foreach=True)
+ super().__init__(all_params, {}, multi_tensor=True)
def _step(self, group):
pass
@@ -1322,7 +1267,7 @@ class SAMWrapper(torch.optim.Optimizer):
def __init__(
self,
params,
- wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = ForeachAdamW,
+ wrapped_optimizer: Union[utils.StatefulOptimizer, Type[utils.StatefulOptimizer]] = AdamW,
ball: float = 0.1,
):
params = list(params)
@@ -1366,31 +1311,5 @@ def zero_grad(self, set_to_none: bool = True):
self.wrapped_optimizer.zero_grad(set_to_none=set_to_none)
-PalmForEachSoap = PaLMForeachSOAP
-PaLMSOAP = PaLMForeachSOAP
-PaLMSFAdamW = PaLMForeachSFAdamW
-SOAP = ForeachSOAP
-SOAPAdEMAMix = ForeachSOAPAdEMAMix
-SOAPNAdam = ForeachSOAPNAdam
-SFAdamW = ForeachSFAdamW
-LaProp = ForeachLaProp
-ADOPT = ForeachADOPT
-RMSprop = ForeachRMSprop
-PrecondScheduleSOAP = PrecondScheduleForeachSOAP
-PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
-PSGDKron = ForeachPSGDKron
-AdamW = ForeachAdamW
-NAdam = ForeachNAdam
-PurePSGD = ForeachPurePSGD
-DelayedPSGD = ForeachDelayedPSGD
-CachedPSGDKron = ForeachCachedPSGDKron
-CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
-Muon = ForeachMuon
-SignLaProp = ForeachSignLaProp
-DelayedPSGDLRA = ForeachDelayedPSGDLRA
-PSGDLRA = ForeachPSGDLRA
-NewtonPSGDLRA = ForeachNewtonPSGDLRA
-NewtonPSGDKron = ForeachCachedNewtonPSGD
-
capture_param_shapes = utils.capture_param_shapes
__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index dc16ce0..6dbf1f9 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -374,7 +374,7 @@ def __call__(self, state, group, update, grad, param, *args, **kwargs):
return self.fn(group, update, grad, param, *args, **kwargs)
-class NoStateNoForeach(FunctionTransform):
+class NoStateNoMultiTensor(FunctionTransform):
def __call__(self, state, group, update, grad, param, *args, **kwargs):
updates = []
skip_update = False
@@ -446,8 +446,8 @@ def no_state(fn):
return NoState(fn)
-def no_state_no_foreach(fn):
- return NoStateNoForeach(fn)
+def no_state_no_multi_tensor(fn):
+ return NoStateNoMultiTensor(fn)
class SkipUpdate(ValueError):
@@ -699,7 +699,7 @@ def orthogonalize_grad_to_param(group, update, grad, param):
@copy_guard(2, "z")
@no_state
def update_by_schedule_free(group, update, grad, param, z):
- # Compute weight_sum once per step, not per param in no-foreach mode.
+ # Compute weight_sum once per step, not per param in no-multi_tensor mode.
if group.get("_sf_step") != group["step"]:
weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
group["weight_sum"] = group.get("weight_sum", 0) + weight
@@ -779,7 +779,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
@needs_full_param
@zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
-@no_state_no_foreach
+@no_state_no_multi_tensor
def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
if group["step"] == 1:
utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
@@ -924,7 +924,7 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
@needs_full_param
-@no_state_no_foreach
+@no_state_no_multi_tensor
def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
if update.dim() < 2:
return update
@@ -1352,7 +1352,7 @@ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@no_state_no_multi_tensor
def scale_by_psgd(
group,
update,
@@ -1375,7 +1375,7 @@ def scale_by_psgd(
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@no_state_no_multi_tensor
def scale_by_delayed_psgd(
group,
update,
@@ -1404,7 +1404,7 @@ def scale_by_delayed_psgd(
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@no_state_no_multi_tensor
def update_by_psgd(
group,
update,
@@ -1440,7 +1440,7 @@ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
-@no_state_no_foreach
+@no_state_no_multi_tensor
def update_by_delayed_psgd(
group,
update,
@@ -1464,7 +1464,7 @@ def update_by_delayed_psgd(
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
-@no_state_no_foreach
+@no_state_no_multi_tensor
def scale_by_psgd_pro(
group,
update,
@@ -1486,7 +1486,7 @@ def scale_by_psgd_pro(
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
-@no_state_no_foreach
+@no_state_no_multi_tensor
def update_by_psgd_pro(
group,
update,
@@ -1745,14 +1745,14 @@ class ChainOpt(utils.StatefulOptimizer):
"eps": 1e-8,
}
- def __init__(self, params, defaults, foreach: bool, *fns):
+ def __init__(self, params, defaults, multi_tensor: bool, *fns):
orig = defaults.pop("orig_shapes", None)
self._orig_shapes = (
{k: _ShapeInfo(v) if isinstance(v, tuple) else v for k, v in orig.items()} if orig is not None else None
)
base = self.global_defaults.copy()
base.update({k: v for k, v in defaults.items() if v is not use_default})
- super().__init__(params, base, foreach)
+ super().__init__(params, base, multi_tensor)
self.fns = fns
self.register_load_state_dict_post_hook(ChainOpt._restore_ecc_dtypes)
self._init_param_ecc()
@@ -1854,7 +1854,7 @@ def _step(self, group):
if "prev_lr" in group and group["prev_lr"] != group["lr"]:
utils.warn_once(
f"Learning rate changed between steps. This is an experimental feature and "
- f"only supported with foreach=True (currently foreach={group['foreach']})."
+ f"only supported with multi_tensor=True (currently multi_tensor={group['multi_tensor']})."
)
group["base_lr"] = group["lr"]
@@ -1890,7 +1890,7 @@ def _step_inner(self, group):
group["step"] = state["step"] = step = step + 1
group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
- if not group["foreach"] or len(p) == 1:
+ if not group["multi_tensor"] or len(p) == 1:
for param, grad in zip(p, g):
chain(self.state_, group, [grad], [param], *self.fns)
group["caution"] = caution
@@ -1936,6 +1936,7 @@ def default(a, b):
scale_by_laprop.get_fn(): update_by_laprop, #
scale_by_adopt.get_fn(): update_by_adopt, #
scale_by_ademamix.get_fn(): update_by_ademamix, #
+ scale_by_psgd_pro.get_fn(): update_by_psgd_pro, #
}
_scale_to_update_map_inv = {
update_by_delayed_psgd.get_fn(): scale_by_delayed_psgd, #
@@ -1947,6 +1948,7 @@ def default(a, b):
update_by_laprop.get_fn(): scale_by_laprop, #
update_by_adopt.get_fn(): scale_by_adopt, #
update_by_ademamix.get_fn(): scale_by_ademamix, #
+ update_by_psgd_pro.get_fn(): scale_by_psgd_pro, #
}
@@ -1961,7 +1963,7 @@ class BaseOpt(ChainOpt):
promote: bool = False
Whether to promote the gradients to fp32 before applying the optimizer
Improves update quality for low-precision parameters, but increases costs
- Compiling the optimizer step would reduce memory and compute. Alternatively, `foreach=False` decreases memory at the cost of runtime
+ Compiling the optimizer step would reduce memory and compute. Alternatively, `multi_tensor=False` decreases memory at the cost of runtime
gradient_clipping: str_or_fn = None
The function to use for clipping the incoming gradients, before any other transformations.
@@ -1983,7 +1985,7 @@ def __init__(
self,
params,
defaults,
- foreach: bool = True,
+ multi_tensor: bool = True,
gradient_clipping: str_or_fn = None,
update_clipping: str_or_fn = None,
palm: bool = use_default,
@@ -2031,7 +2033,7 @@ def __init__(
if default(update_clipping, self.update_clipping) is not None:
fns = fns + (apply_to_idx(update_clipping, 2),)
- super().__init__(params, defaults, foreach, *fns)
+ super().__init__(params, defaults, multi_tensor, *fns)
class ScheduleFree(BaseOpt):
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 7cb8675..0eab0be 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1305,13 +1305,13 @@ class StatefulOptimizer(torch.optim.Optimizer):
"hessian_approx",
)
- def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
+ def __init__(self, params, defaults, multi_tensor: bool = True, use_ema: bool = False):
for attr in self._INSTANCE_ATTRS:
if attr in defaults:
val = defaults.pop(attr)
if val is not use_default:
setattr(self, attr, val)
- super().__init__(params, {**defaults, "foreach": foreach})
+ super().__init__(params, {**defaults, "multi_tensor": multi_tensor})
self.use_ema = use_ema
self.mapping = {}
self.mapping_inverse = {}
@@ -2151,7 +2151,7 @@ def _compilable_update_(
caution: bool,
g: List[Optional[Tensor]],
):
- for i, (u_, g_, p_) in enumerate(zip(u, g, p)): # lr is data-dependent -> can't compile a foreach
+ for i, (u_, g_, p_) in enumerate(zip(u, g, p)): # lr is data-dependent -> can't compile a multi-tensor op
u_ = promote(u_.view_as(p_))
p32_ = promote(p_)
if caution:
@@ -2620,7 +2620,7 @@ def multi_flatten(*xs: Tuple[List[Tensor], int]):
@decorator_knowngood
-def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
+def dampen_multiple(g: List[Tensor], damp: float = 1e-9):
vs = []
gs = []
for g_ in g:
@@ -2631,8 +2631,7 @@ def dampen_multiple(g: List[Tensor], damp: float = 2**-13):
def casted_einsum(expr: str, *args: Tensor) -> Tensor:
- md = min_dtype(args)
- return compiled_einsum(expr, *[a.to(md) for a in args]).to(args[-1].dtype)
+ return compiled_einsum(expr, *[promote(a) for a in args]).to(args[-1].dtype)
@decorator_knowngood
@@ -3012,7 +3011,7 @@ def psgd_update_precond(
return None
-@decorator
+@decorator_knowngood
def psgd_pro_update_precond(
G: Tensor,
precond_lr: float,
@@ -3548,9 +3547,8 @@ def precond_grad_cached_(
):
if caution:
ea = _compilable_cautioning(grad, ea)
- md = min_dtype(list(cached_q) + [ea])
- args = [q.to(md) for q in cached_q]
- args = args + [ea.to(md)]
+ args = [promote(q) for q in cached_q]
+ args = args + [promote(ea)]
expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
new = compiled_einsum(expr, *args)
if cast:
@@ -3591,16 +3589,18 @@ def psgd_precond_grad(
grad: Optional[Tensor] = None,
store_triu_as_line: bool = False,
symmetric_output: bool = False,
+ cast: bool = True,
):
if caution:
ea = _compilable_cautioning(grad, ea)
if store_triu_as_line:
preconds = line_to_triu(preconds, symmetric_output)
- md = min_dtype(list(preconds) + [ea])
- args = [q.to(md) for q in preconds]
+ args = [promote(q) for q in preconds]
expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
- new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], ea.to(md))
- return new.to(ea.dtype)
+ new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
+ if cast:
+ return new.to(ea.dtype)
+ return new
@decorator_knowngood
@@ -3622,6 +3622,7 @@ def _compilable_fused_psgd_precond_grad(
grad=grad,
store_triu_as_line=store_triu_as_line,
symmetric_output=symmetric_output,
+ cast=False,
)
update_param_(param, precond, lr, decay, caution=False, grad=grad)
diff --git a/scripts/migrate_optimizer_state.py b/scripts/migrate_optimizer_state.py
index fb9f6e2..1eef0aa 100644
--- a/scripts/migrate_optimizer_state.py
+++ b/scripts/migrate_optimizer_state.py
@@ -234,7 +234,7 @@ def migrate(
),
optimizer_class: str = typer.Argument(
...,
- help="Optimizer class to instantiate (e.g., heavyball.ForeachAdamW)",
+ help="Optimizer class to instantiate (e.g., heavyball.AdamW)",
),
state_key: str = typer.Option(
"optimizer",
diff --git a/test/test_ademamix.py b/test/test_ademamix.py
index 847a290..49e53de 100644
--- a/test/test_ademamix.py
+++ b/test/test_ademamix.py
@@ -101,7 +101,7 @@ def test_ademamix_matches_reference_math():
alpha_warmup = 4
param = torch.nn.Parameter(initial.clone())
- optimizer = heavyball.ForeachAdEMAMix(
+ optimizer = heavyball.AdEMAMix(
[param],
lr=lr,
betas=betas,
@@ -110,7 +110,7 @@ def test_ademamix_matches_reference_math():
alpha=alpha,
beta3_warmup=beta3_warmup,
alpha_warmup=alpha_warmup,
- foreach=False,
+ multi_tensor=False,
)
for grad in grads:
@@ -149,7 +149,7 @@ def test_soap_ademamix_projects_gradients_into_eigenbasis():
torch.manual_seed(7)
param = torch.nn.Parameter(torch.randn(2, 2))
- optimizer = heavyball.ForeachSOAPAdEMAMix([param], lr=0.01, foreach=False)
+ optimizer = heavyball.SOAPAdEMAMix([param], lr=0.01, multi_tensor=False)
# First call initializes the SOAP preconditioner state without applying an update.
param.grad = torch.randn_like(param)
diff --git a/test/test_chainable_cpu.py b/test/test_chainable_cpu.py
index 124cbe4..6612028 100644
--- a/test/test_chainable_cpu.py
+++ b/test/test_chainable_cpu.py
@@ -70,46 +70,36 @@ def state_fn(_x):
# Optimizers whose chains are purely elementwise must NOT need gather
_EXPECT_NO_GATHER = {
"SGD",
- "ForeachAdamW",
- "ForeachNAdam",
- "ForeachAdEMAMix",
+ "AdamW",
+ "NAdam",
+ "AdEMAMix",
"UnscaledAdamW",
- "ForeachAdamC",
- "ForeachRMSprop",
- "ForeachSFAdamW",
- "ForeachADOPT",
- "ForeachLaProp",
- "PaLMForeachSFAdamW",
+ "AdamC",
+ "RMSprop",
+ "SFAdamW",
+ "ADOPT",
+ "LaProp",
}
# Optimizers whose chains use shape-dependent or global-reduction ops must need gather
_EXPECT_GATHER = {
- "ForeachSOAP",
- "ForeachSOAPNAdam",
- "ForeachSOAPAdEMAMix",
- "ForeachSOLP",
- "ForeachMuon",
+ "SOAP",
+ "SOAPNAdam",
+ "SOAPAdEMAMix",
+ "SOLP",
+ "Muon",
"MuonLaProp",
"OrthoLaProp",
"LaPropOrtho",
- "ForeachPSGDKron",
- "ForeachPurePSGD",
- "ForeachCachedPSGDKron",
- "ForeachDelayedPSGD",
- "ForeachCachedDelayedPSGDKron",
- "ForeachCachedNewtonPSGD",
- "NewtonHybrid2PSGDKron",
- "ForeachPSGDLRA",
- "ForeachDelayedPSGDLRA",
- "ForeachNewtonPSGDLRA",
- "NewtonHybrid2PSGDLRA",
+ "PSGDKron",
+ "PSGDLRA",
+ "PSGDPRO",
"SUDSAdamW",
"Scion",
- "ForeachSignLaProp",
+ "SignLaProp",
"MSAMLaProp",
- "PaLMForeachSOAP",
- "PrecondScheduleForeachSOAP",
- "PrecondSchedulePaLMForeachSOAP",
+ "HyperBallAdamW",
+ "MuonAdamW",
}
_SKIP_INSTANTIATE = {"SplitOpt", "SAMWrapper"}
@@ -120,7 +110,7 @@ def state_fn(_x):
@pytest.mark.parametrize("opt_name", _ALL_OPTS)
def test_needs_gather_flag(opt_name):
params = [torch.nn.Parameter(torch.randn(4, 4))]
- extra = {"max_lr": 0.0025} if opt_name == "ForeachAdamC" else {}
+ extra = {"max_lr": 0.0025} if opt_name == "AdamC" else {}
opt = getattr(heavyball, opt_name)(params, lr=1e-3, **extra)
if opt_name in _EXPECT_NO_GATHER:
assert not opt._needs_gather, f"{opt_name} should be elementwise (no gather needed)"
diff --git a/test/test_cpu_features.py b/test/test_cpu_features.py
index e241f98..82ed343 100644
--- a/test/test_cpu_features.py
+++ b/test/test_cpu_features.py
@@ -44,9 +44,9 @@ def _make_batch(
@pytest.mark.parametrize(
"opt_name",
[
- "ForeachSOAP",
+ "SOAP",
"Muon",
- "ForeachAdamW",
+ "AdamW",
],
)
def test_selected_optimizers_run_on_cpu(opt_name: str) -> None:
@@ -92,8 +92,8 @@ def test_mars_flag_changes_behavior() -> None:
model_a, data, target = _make_batch()
model_b = deepcopy(model_a)
- opt_a = heavyball.ForeachAdamW(model_a.parameters(), mars=False, warmup_steps=0)
- opt_b = heavyball.ForeachAdamW(model_b.parameters(), mars=True, warmup_steps=0)
+ opt_a = heavyball.AdamW(model_a.parameters(), mars=False, warmup_steps=0)
+ opt_b = heavyball.AdamW(model_b.parameters(), mars=True, warmup_steps=0)
init = [param.detach().clone() for param in model_a.parameters()]
@@ -112,7 +112,7 @@ def test_mars_flag_changes_behavior() -> None:
def test_sam_wrapper_requires_closure() -> None:
model = nn.Linear(4, 2)
- base = heavyball.ForeachAdamW(model.parameters())
+ base = heavyball.AdamW(model.parameters())
wrapper = heavyball.SAMWrapper(model.parameters(), wrapped_optimizer=base)
with pytest.raises(ValueError):
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 1b60e7b..1663f13 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -17,14 +17,13 @@
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
]
-_EXTRA_KWARGS = {"ForeachAdamC": {"max_lr": 0.0025}}
+_EXTRA_KWARGS = {"AdamC": {"max_lr": 0.0025}}
_MODEL_SEED = 42
_DATA_SEED = 0xABCD
# LRA builds one preconditioner over all grads, under FSDP each rank only has a subset
_FSDP_SKIP = {
- "ForeachPSGDLRA": "LRA preconditioner scope differs under FSDP",
- "ForeachDelayedPSGDLRA": "LRA preconditioner scope differs under FSDP",
+ "PSGDLRA": "LRA preconditioner scope differs under FSDP",
}
# torch.compile(dynamic=False) specializes on list length → different kernels per rank
@@ -39,16 +38,14 @@
_INTEGRATION_OPTS = [
n
for n in [
- "ForeachAdamW",
- "ForeachSOAP",
- "ForeachMuon",
- "ForeachPSGDKron",
- "ForeachPurePSGD",
+ "AdamW",
+ "SOAP",
+ "Muon",
+ "PSGDKron",
"Scion",
- "ForeachLaProp",
+ "LaProp",
"MuonLaProp",
- "ForeachSOAPNAdam",
- "ForeachCachedPSGDKron",
+ "SOAPNAdam",
]
if n in REPRESENTATIVE_OPTS and n not in _FSDP_SKIP
]
diff --git a/test/test_ecc.py b/test/test_ecc.py
index 97c0eab..1dd7364 100644
--- a/test/test_ecc.py
+++ b/test/test_ecc.py
@@ -14,13 +14,13 @@
ULP_MODES = ["bf16+8", "bf16+16", "fp16+8", "fp16+16"]
_OPTIMIZERS = [
- (heavyball.ForeachAdamW, 5e-2, {}),
- (heavyball.ForeachADOPT, 5e-2, {}),
- (heavyball.ForeachNAdam, 1e-2, {}),
- (heavyball.ForeachLaProp, 5e-2, {}),
- (heavyball.ForeachAdEMAMix, 5e-2, {"betas": (0.9, 0.999, 0.9999)}),
- (heavyball.ForeachRMSprop, 1e-2, {}),
- (heavyball.PaLMForeachSFAdamW, 1e-2, {}),
+ (heavyball.AdamW, 5e-2, {}),
+ (heavyball.ADOPT, 5e-2, {}),
+ (heavyball.NAdam, 1e-2, {}),
+ (heavyball.LaProp, 5e-2, {}),
+ (heavyball.AdEMAMix, 5e-2, {"betas": (0.9, 0.999, 0.9999)}),
+ (heavyball.RMSprop, 1e-2, {}),
+ (heavyball.SFAdamW, 1e-2, {}),
]
@@ -125,12 +125,12 @@ def test_ecc_convergence(opt_cls, lr, extra_kw, mode):
def test_param_ecc_convergence(combined):
set_torch()
data, target = _problem()
- m0, o0 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2)
+ m0, o0 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2)
losses_base = _train(m0, o0, data, target, 200)
kw = {"param_ecc": "bf16+8"}
if combined:
kw["ecc"] = "bf16+8"
- m1, o1 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, **kw)
+ m1, o1 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, **kw)
losses_ecc = _train(m1, o1, data, target, 200)
p = list(m1.parameters())[0]
assert p.dtype == torch.bfloat16
@@ -148,7 +148,7 @@ def test_state_layout_and_invariants(mode):
cfg = ECCConfig(mode)
torch.manual_seed(42)
model = nn.Linear(32, 16, bias=False, device="cuda")
- opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, ecc=mode)
+ opt = heavyball.AdamW(model.parameters(), lr=1e-2, ecc=mode)
x = torch.randn(4, 32, device="cuda")
for _ in range(3):
model(x).sum().backward()
@@ -174,7 +174,7 @@ def test_ademamix_three_vars():
set_torch()
torch.manual_seed(42)
model = nn.Linear(64, 32, bias=False, device="cuda")
- opt = heavyball.ForeachAdEMAMix(model.parameters(), lr=5e-2, betas=(0.9, 0.999, 0.9999), ecc="bf16+8")
+ opt = heavyball.AdEMAMix(model.parameters(), lr=5e-2, betas=(0.9, 0.999, 0.9999), ecc="bf16+8")
x = torch.randn(4, 64, device="cuda")
for _ in range(10):
model(x).sum().backward()
@@ -193,7 +193,7 @@ def test_ademamix_three_vars():
def test_combined_ecc_dtypes():
set_torch()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 32, 16, 1e-2, ecc="bf16+16", param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 32, 16, 1e-2, ecc="bf16+16", param_ecc="bf16+8")
data, target = _problem(in_dim=32, out_dim=16, n=8)
_train(m, o, data, target, 100)
p = list(m.parameters())[0]
@@ -211,7 +211,7 @@ def test_shapes_and_bias():
set_torch()
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(32, 16, bias=True), nn.Linear(16, 4, bias=False)).cuda()
- opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, ecc="bf16+8")
+ opt = heavyball.AdamW(model.parameters(), lr=1e-2, ecc="bf16+8")
data, target = torch.randn(16, 32, device="cuda"), torch.randn(16, 4, device="cuda")
losses = _train(model, opt, data, target, 50)
for p in model.parameters():
@@ -225,9 +225,9 @@ def test_shapes_and_bias():
def test_foreach_false():
set_torch()
data, target = _problem(in_dim=32, out_dim=4, n=16)
- m_fe, o_fe = _model_opt(heavyball.ForeachAdamW, 32, 4, 1e-2, ecc="bf16+8", foreach=True)
+ m_fe, o_fe = _model_opt(heavyball.AdamW, 32, 4, 1e-2, ecc="bf16+8", multi_tensor=True)
losses_fe = _train(m_fe, o_fe, data, target, 50)
- m_nf, o_nf = _model_opt(heavyball.ForeachAdamW, 32, 4, 1e-2, ecc="bf16+8", foreach=False)
+ m_nf, o_nf = _model_opt(heavyball.AdamW, 32, 4, 1e-2, ecc="bf16+8", multi_tensor=False)
losses_nf = _train(m_nf, o_nf, data, target, 50)
assert losses_nf[-1] < losses_nf[0] * 0.5
assert 0.3 < losses_nf[-1] / max(losses_fe[-1], 1e-12) < 3.0
@@ -239,7 +239,7 @@ def test_param_groups():
set_torch()
torch.manual_seed(42)
m1, m2 = nn.Linear(16, 8, bias=False, device="cuda"), nn.Linear(8, 4, bias=False, device="cuda")
- opt = heavyball.ForeachAdamW(
+ opt = heavyball.AdamW(
[
{"params": m1.parameters(), "ecc": "bf16+8"},
{"params": m2.parameters()},
@@ -261,7 +261,7 @@ def test_zero_gradients():
set_torch()
torch.manual_seed(42)
p = nn.Parameter(torch.randn(16, 8, device="cuda"))
- opt = heavyball.ForeachAdamW([p], lr=1e-2, ecc="bf16+8")
+ opt = heavyball.AdamW([p], lr=1e-2, ecc="bf16+8")
for _ in range(10):
p.grad = torch.zeros_like(p)
opt.step()
@@ -276,10 +276,10 @@ def test_state_save_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt, sd_model = deepcopy(o.state_dict()), deepcopy(m.state_dict())
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
losses_after = _train(m2, o2, data, target, 10)
@@ -302,7 +302,7 @@ def _measure_peak(cls, n, lr, ecc=None, param_ecc=None, steps=3):
kw["ecc"] = ecc
if param_ecc:
kw["param_ecc"] = param_ecc
- opt = cls([p], lr=lr, **kw)
+ opt = cls([p], lr=lr, **kw)
for _ in range(steps):
p.grad = torch.randn_like(p)
@@ -326,9 +326,9 @@ def _measure_peak(cls, n, lr, ecc=None, param_ecc=None, steps=3):
@pytest.mark.parametrize(
"cls,lr",
[
- (heavyball.ForeachAdamW, 1e-3),
- (heavyball.PaLMForeachSFAdamW, 1e-2),
- (heavyball.ForeachRMSprop, 1e-2),
+ (heavyball.AdamW, 1e-3),
+ (heavyball.SFAdamW, 1e-2),
+ (heavyball.RMSprop, 1e-2),
],
ids=["AdamW", "SFAdamW", "RMSprop"],
)
@@ -346,7 +346,7 @@ def test_ecc_peak_memory(cls, lr, mode):
@pytest.mark.parametrize("combined", [False, True], ids=["param_only", "state+param"])
def test_param_ecc_peak_memory(combined):
n = 2**24
- cls, lr = heavyball.PaLMForeachSFAdamW, 1e-2
+ cls, lr = heavyball.SFAdamW, 1e-2
pre_base, peak_base = _measure_peak(cls, n, lr)
ecc = "bf16+8" if combined else None
pre_ecc, peak_ecc = _measure_peak(cls, n, lr, ecc=ecc, param_ecc="bf16+8")
@@ -357,7 +357,7 @@ def test_param_ecc_peak_memory(combined):
def test_ecc_live_path_nonzero_correction():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.ForeachAdamW, 16, 8, 5e-2, ecc="bf16+8")
+ m, o = _model_opt(heavyball.AdamW, 16, 8, 5e-2, ecc="bf16+8")
_train(m, o, data, target, 20)
p = list(m.parameters())[0]
st, ecc_keys = _ecc_keys(o, p)
@@ -369,7 +369,7 @@ def test_ecc_live_path_nonzero_correction():
def test_lerp_returns_valid_fp32():
set_torch()
- m, o = _model_opt(heavyball.ForeachAdamW, 16, 8, 5e-2, ecc="bf16+8")
+ m, o = _model_opt(heavyball.AdamW, 16, 8, 5e-2, ecc="bf16+8")
data, target = _problem()
_train(m, o, data, target, 10)
p = list(m.parameters())[0]
@@ -398,7 +398,7 @@ def test_param_ecc_merge_dims():
nn.Flatten(),
nn.Linear(64, 8, bias=False),
).cuda()
- opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
+ opt = heavyball.SOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
conv_p = list(model.parameters())[0]
assert conv_p.dtype == torch.bfloat16
st = _flat_state(opt, conv_p)
@@ -417,7 +417,7 @@ def test_param_ecc_merge_dims():
def test_param_ecc_dtype_at_construction():
set_torch()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
p = list(m.parameters())[0]
assert p.dtype == torch.bfloat16, "param should be bf16 immediately after construction"
st = _flat_state(o, p)
@@ -431,11 +431,11 @@ def test_param_ecc_save_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt, sd_model = deepcopy(o.state_dict()), deepcopy(m.state_dict())
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
p2 = list(m2.parameters())[0]
@@ -452,7 +452,7 @@ def test_param_ecc_partial_state_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
@@ -460,7 +460,7 @@ def test_param_ecc_partial_state_restore():
for idx_state in param_state.values():
if isinstance(idx_state, dict):
idx_state.pop("param::ecc", None)
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
st2 = _flat_state(o2, list(m2.parameters())[0])
@@ -478,12 +478,12 @@ def test_param_ecc_empty_state_restore():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
sd_opt["state"] = {}
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
p2 = list(m2.parameters())[0]
@@ -508,7 +508,7 @@ def test_param_ecc_merged_view_partial_restore():
nn.Flatten(),
nn.Linear(64, 8, bias=False),
).cuda()
- opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
+ opt = heavyball.SOAP(model.parameters(), lr=1e-3, param_ecc="bf16+8")
data = torch.randn(4, 3, 8, 8, device="cuda")
target = torch.randn(4, 8, device="cuda")
conv_p = list(model.parameters())[0]
@@ -529,7 +529,7 @@ def test_param_ecc_merged_view_partial_restore():
nn.Flatten(),
nn.Linear(64, 8, bias=False),
).cuda()
- opt2 = heavyball.ForeachSOAP(model2.parameters(), lr=1e-3, param_ecc="bf16+8")
+ opt2 = heavyball.SOAP(model2.parameters(), lr=1e-3, param_ecc="bf16+8")
model2.load_state_dict(sd_model)
opt2.load_state_dict(sd_opt)
conv_p2 = list(model2.parameters())[0]
@@ -549,7 +549,7 @@ def test_param_ecc_merged_view_partial_restore():
def test_optimizer_kwargs_not_in_param_groups():
set_torch()
p = torch.nn.Parameter(torch.randn(4, 4, device="cuda"))
- o = heavyball.ForeachAdamW([p], lr=1e-3, compile_step=True, promote=True)
+ o = heavyball.AdamW([p], lr=1e-3, compile_step=True, promote=True)
assert o.compile_step is True
assert o.promote is True
assert "compile_step" not in o.param_groups[0]
@@ -564,7 +564,7 @@ def test_param_ecc_load_order_model_before_optimizer():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
@@ -573,7 +573,7 @@ def test_param_ecc_load_order_model_before_optimizer():
if isinstance(idx_state, dict):
idx_state.pop("param::ecc", None)
# model-first: load model, then optimizer
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
m2.load_state_dict(sd_model)
o2.load_state_dict(sd_opt)
p2 = list(m2.parameters())[0]
@@ -594,7 +594,7 @@ def test_param_ecc_load_order_optimizer_before_model():
set_torch()
data, target = _problem()
- m, o = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m, o = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
_train(m, o, data, target, 10)
sd_opt = deepcopy(o.state_dict())
sd_model = deepcopy(m.state_dict())
@@ -603,7 +603,7 @@ def test_param_ecc_load_order_optimizer_before_model():
if isinstance(idx_state, dict):
idx_state.pop("param::ecc", None)
# optimizer-first: load optimizer, then model
- m2, o2 = _model_opt(heavyball.PaLMForeachSFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
+ m2, o2 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, param_ecc="bf16+8")
o2.load_state_dict(sd_opt)
m2.load_state_dict(sd_model)
p2 = list(m2.parameters())[0]
diff --git a/test/test_foreach.py b/test/test_foreach.py
index de3fa2a..06dd9f4 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -57,12 +57,12 @@ def test_foreach(
losses = [[], []]
for i in range(total_runs):
- for foreach in [True, False]:
- lss, pk = losses[int(foreach)], peaks[int(foreach)]
+ for multi_tensor in [True, False]:
+ lss, pk = losses[int(multi_tensor)], peaks[int(multi_tensor)]
torch.manual_seed(0x2131290)
model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
- o = get_optim(opt, model.parameters(), lr=1e-3, foreach=foreach)
+ o = get_optim(opt, model.parameters(), lr=1e-3, multi_tensor=multi_tensor)
clean()
@@ -92,17 +92,17 @@ def test_foreach(
cutoff = warmup_runs * iterations
losses = [loss_list[cutoff:] for loss_list in losses]
- for peak_no_foreach, peak_foreach in zip(*peaks):
- assert peak_no_foreach < peak_foreach
+ for peak_single, peak_multi in zip(*peaks):
+ assert peak_single < peak_multi
- # no-foreach LRA is a different optimizer (per-parameter LRA vs global LRA),
+ # single-tensor LRA is a different optimizer (per-parameter LRA vs global LRA),
# so we only check that both converge, not that they match.
if "LRA" in opt.__name__:
return
- for loss_no_foreach, loss_foreach in zip(*losses):
- if torch.isnan(loss_no_foreach) and torch.isnan(loss_foreach):
+ for loss_single, loss_multi in zip(*losses):
+ if torch.isnan(loss_single) and torch.isnan(loss_multi):
continue
# increase error tolerance for PSGD, as we have different RNGs -> expected differences
- assert torch.allclose(loss_no_foreach, loss_foreach, rtol=0.01 if "PSGD" in opt.__name__ else 1e-5)
+ assert torch.allclose(loss_single, loss_multi, rtol=0.01 if "PSGD" in opt.__name__ else 1e-5)
diff --git a/test/test_memory_leak.py b/test/test_memory_leak.py
index f7088dc..f6c6f50 100644
--- a/test/test_memory_leak.py
+++ b/test/test_memory_leak.py
@@ -34,7 +34,7 @@ def forward(self, x):
def test_memory(
- opt: str = "NewtonHybrid2PSGDKron",
+ opt: str = "PSGDKron",
size: int = 64,
depth: int = 2,
mars: bool = False,
diff --git a/test/test_merge.py b/test/test_merge.py
index 90bafb2..d1a660f 100644
--- a/test/test_merge.py
+++ b/test/test_merge.py
@@ -18,7 +18,7 @@ def forward(self, inp):
return self.weight.mean() * inp
-@pytest.mark.parametrize("opt", ["ForeachPSGDKron"])
+@pytest.mark.parametrize("opt", ["PSGDKron"])
@pytest.mark.parametrize("size", [(16, 16, 16, 16), (4, 4, 4, 4), (512, 1, 128), (32128, 768)])
@pytest.mark.parametrize("merge,split", [(False, False), (True, False), (True, True)])
def test_merge(
diff --git a/test/test_migrate_cli.py b/test/test_migrate_cli.py
index 1554d09..7507289 100644
--- a/test/test_migrate_cli.py
+++ b/test/test_migrate_cli.py
@@ -130,7 +130,7 @@ def test_cli_migrates_legacy_checkpoint(runner, tmp_path):
"eps": 1e-8,
"weight_decay": 0.0,
"warmup_steps": 0,
- "foreach": True,
+ "multi_tensor": True,
"storage_dtype": "float32",
"mars": False,
"caution": False,
@@ -139,7 +139,7 @@ def test_cli_migrates_legacy_checkpoint(runner, tmp_path):
"update_clipping": "use_default",
"palm": "use_default",
"beta2_scale": 0.8,
- "__class__": "heavyball.ForeachAdamW",
+ "__class__": "heavyball.AdamW",
}
],
}
@@ -148,7 +148,7 @@ def test_cli_migrates_legacy_checkpoint(runner, tmp_path):
result = runner.invoke(
migrate_script.app,
- [str(checkpoint_path), "heavyball.ForeachAdamW", "--output", str(output_path)],
+ [str(checkpoint_path), "heavyball.AdamW", "--output", str(output_path)],
)
assert result.exit_code == 0, result.stderr or result.stdout
diff --git a/test/test_optimizer_cpu_smoke.py b/test/test_optimizer_cpu_smoke.py
index 9a9fa68..80b82c7 100644
--- a/test/test_optimizer_cpu_smoke.py
+++ b/test/test_optimizer_cpu_smoke.py
@@ -31,7 +31,7 @@ def _optimizer_params():
)
)
continue
- if name == "ForeachSOAPNAdam":
+ if name == "SOAPNAdam":
params.append(
pytest.param(
name,
diff --git a/test/test_param_ecc_compile.py b/test/test_param_ecc_compile.py
index 7182344..e6fc4ce 100644
--- a/test/test_param_ecc_compile.py
+++ b/test/test_param_ecc_compile.py
@@ -70,7 +70,7 @@ def _train_linear(compile_mode, rne=False, steps=50):
try:
with ctx:
model = torch.nn.Linear(64, 32, bias=False, device="cuda")
- opt = heavyball.ForeachAdamW(model.parameters(), lr=1e-2, param_ecc="bf16+8")
+ opt = heavyball.AdamW(model.parameters(), lr=1e-2, param_ecc="bf16+8")
data = torch.randn(32, 64, device="cuda")
target = torch.randn(32, 32, device="cuda")
for _ in range(steps):
diff --git a/test/test_toy_training.py b/test/test_toy_training.py
index 08859c5..0e2e83a 100644
--- a/test/test_toy_training.py
+++ b/test/test_toy_training.py
@@ -23,7 +23,7 @@ def _flatten_tensors(tensors: Iterable[torch.Tensor]):
EXTRA_KWARGS = {
- "ForeachAdamC": {"max_lr": 0.0025},
+ "AdamC": {"max_lr": 0.0025},
}
@@ -62,15 +62,15 @@ def toy_training_results(request):
sig = inspect.signature(optimizer_cls.__init__)
kwargs = dict(EXTRA_KWARGS.get(optimizer_name, {}))
- if "foreach" in sig.parameters:
- kwargs["foreach"] = True
+ if "multi_tensor" in sig.parameters:
+ kwargs["multi_tensor"] = True
if optimizer_name == "SAMWrapper":
inner_kwargs = {}
- inner_sig = inspect.signature(heavyball.ForeachAdamW.__init__)
- if "foreach" in inner_sig.parameters:
- inner_kwargs["foreach"] = True
- inner_optimizer = heavyball.ForeachAdamW(param_list, **inner_kwargs)
+ inner_sig = inspect.signature(heavyball.AdamW.__init__)
+ if "multi_tensor" in inner_sig.parameters:
+ inner_kwargs["multi_tensor"] = True
+ inner_optimizer = heavyball.AdamW(param_list, **inner_kwargs)
optimizer = optimizer_cls(param_list, wrapped_optimizer=inner_optimizer, **kwargs)
else:
optimizer = optimizer_cls(param_list, **kwargs)
diff --git a/test/utils.py b/test/utils.py
index 800c2fc..906526a 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -13,17 +13,11 @@
# AdEMAMix variants require 3 betas, SplitOpt requires dict param specs,
# SAMWrapper and Newton variants require closures
_SKIP_GET_OPTIM = {
- "ForeachAdEMAMix",
- "ForeachSOAPAdEMAMix",
+ "AdEMAMix",
+ "SOAPAdEMAMix",
"SOAPAdEMAMix",
"SplitOpt",
"SAMWrapper",
- "ForeachCachedNewtonPSGD",
- "NewtonHybrid2PSGDKron",
- "ForeachNewtonPSGDLRA",
- "NewtonHybrid2PSGDLRA",
- "NewtonPSGDLRA",
- "NewtonPSGDKron",
}
@@ -40,7 +34,7 @@ def _fn_key(f):
def _deduplicate_by_chain(names):
"""Keep one optimizer per unique chain of functions.
- Two optimizers that differ only by foreach=True/False have identical
+ Two optimizers that differ only by multi_tensor=True/False have identical
chains and test the same code paths — keep whichever appears first.
"""
seen = set()
From 4c9ac161d345f8403801b9922527db048350c40a Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sat, 28 Mar 2026 18:34:58 +0000
Subject: [PATCH 04/24] higher psgd precision
---
heavyball/utils.py | 30 +++++++++++-------------------
1 file changed, 11 insertions(+), 19 deletions(-)
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 0eab0be..1114b4b 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -2363,7 +2363,7 @@ def init_Q_exprs(
@decorator_knowngood
def psgd_balance_Q(Q):
- norms = [promote(q.norm(float("inf"))).log() for q in Q]
+ norms = [promote(q.abs().max()).log() for q in Q]
geometric_mean = sum([n for n in norms]) / len(Q)
for q, n in zip(Q, norms):
q *= (geometric_mean - n).exp()
@@ -2726,7 +2726,7 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
Adapted from @evanatyourservice
"""
if max_abs is None:
- max_abs = A.norm(float("inf")).clamp(min=1e-8)
+ max_abs = A.abs().max().clamp(min=1e-8)
# cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
k = 2 ** math.ceil(math.log2(math.log2(min(A.shape)))) # next-largest-power-of-2 of log2-of-size
@@ -2794,15 +2794,16 @@ def _mv(x):
@decorator_knowngood
def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None:
- R = (Q.T - Q).contiguous()
+ Q_ = promote(Q)
+ R = (Q_.T - Q_).contiguous()
R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal
R = R / R_norm
- RQ = R @ Q
+ RQ = R @ Q_
RRQ = R @ RQ
tr_RQ = RQ.diagonal().sum()
tr_RRQ = RRQ.diagonal().sum()
a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size)
- Q.add_(a * (RQ + 0.5 * a * RRQ))
+ copy_stochastic_(Q, Q_ + a * (RQ + 0.5 * a * RRQ))
@decorator_knowngood
@@ -3031,7 +3032,7 @@ def psgd_pro_update_precond(
total_numel = G.numel()
for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound):
- term1 = promote(compiled_einsum(exprG, Pg, Pg))
+ term1 = compiled_einsum(exprG, Pg, Pg)
q_ = promote(q)
if q.ndim < 2:
@@ -3159,7 +3160,7 @@ def _psgd_precond_update_(
q = promote(oq)
if update.ndim < 2:
- lb = update.norm(float("inf"))
+ lb = update.abs().max()
else:
lb = max_singular_value(update, power_iter=power_iter)
update = promote(update)
@@ -3543,17 +3544,13 @@ def precond_grad_cached_(
cached_q: List[Tensor],
caution: bool = False,
grad: Optional[Tensor] = None,
- cast: bool = True,
):
if caution:
ea = _compilable_cautioning(grad, ea)
args = [promote(q) for q in cached_q]
args = args + [promote(ea)]
expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
- new = compiled_einsum(expr, *args)
- if cast:
- return new.to(ea.dtype)
- return new
+ return compiled_einsum(expr, *args)
TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
@@ -3561,7 +3558,7 @@ def precond_grad_cached_(
@decorator_knowngood
def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
- precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
+ precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad)
update_param_(param, precond, lr, decay, caution=False)
@@ -3589,7 +3586,6 @@ def psgd_precond_grad(
grad: Optional[Tensor] = None,
store_triu_as_line: bool = False,
symmetric_output: bool = False,
- cast: bool = True,
):
if caution:
ea = _compilable_cautioning(grad, ea)
@@ -3597,10 +3593,7 @@ def psgd_precond_grad(
preconds = line_to_triu(preconds, symmetric_output)
args = [promote(q) for q in preconds]
expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
- new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
- if cast:
- return new.to(ea.dtype)
- return new
+ return compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
@decorator_knowngood
@@ -3622,7 +3615,6 @@ def _compilable_fused_psgd_precond_grad(
grad=grad,
store_triu_as_line=store_triu_as_line,
symmetric_output=symmetric_output,
- cast=False,
)
update_param_(param, precond, lr, decay, caution=False, grad=grad)
From 8526f7b118ef16480ccea93b6c94ce9017b4f713 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sat, 28 Mar 2026 21:15:18 +0000
Subject: [PATCH 05/24] init compile step
---
heavyball/chainable.py | 27 +++++++++++--
heavyball/utils.py | 2 +
test/test_compile_step.py | 83 +++++++++++++++++++++++++++++++++++++++
3 files changed, 109 insertions(+), 3 deletions(-)
create mode 100644 test/test_compile_step.py
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index 6dbf1f9..bdad159 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -1754,6 +1754,8 @@ def __init__(self, params, defaults, multi_tensor: bool, *fns):
base.update({k: v for k, v in defaults.items() if v is not use_default})
super().__init__(params, base, multi_tensor)
self.fns = fns
+ if self.compile_step:
+ self._chain = torch.compile(self._chain, fullgraph=True)
self.register_load_state_dict_post_hook(ChainOpt._restore_ecc_dtypes)
self._init_param_ecc()
@@ -1890,17 +1892,36 @@ def _step_inner(self, group):
group["step"] = state["step"] = step = step + 1
group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
+ if isinstance(step, torch.Tensor):
+ _orig_floats = {}
+ for k, v in group.items():
+ if isinstance(v, float):
+ _orig_floats[k] = v
+ group[k] = torch.tensor(v, dtype=torch.float64, device=step.device)
+ elif isinstance(v, tuple) and any(isinstance(x, float) for x in v):
+ _orig_floats[k] = v
+ group[k] = tuple(
+ torch.tensor(x, dtype=torch.float64, device=step.device) if isinstance(x, float) else x
+ for x in v
+ )
+
if not group["multi_tensor"] or len(p) == 1:
for param, grad in zip(p, g):
- chain(self.state_, group, [grad], [param], *self.fns)
- group["caution"] = caution
+ self._chain(group, [grad], [param], caution)
else:
- chain(self.state_, group, g, p, *self.fns)
+ self._chain(group, g, p, caution)
+
+ if isinstance(step, torch.Tensor):
+ group.update(_orig_floats)
group["caution"] = caution
group["lr"] = group["prev_lr"]
group["step"] = None
+ def _chain(self, group, g, p, caution):
+ chain(self.state_, group, g, p, *self.fns)
+ group["caution"] = caution
+
str_or_fn = Union[str, callable, None, Literal[use_default]]
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 1114b4b..d0154b5 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -988,6 +988,8 @@ def scalar_guard(*args):
out.append(torch.empty((), dtype=promote(ref.dtype), device=ref.device).fill_(x))
elif isinstance(x, int):
out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+ elif isinstance(x, Tensor) and x.is_floating_point() and x.ndim == 0:
+ out.append(x.to(dtype=promote(ref.dtype)))
else:
out.append(x)
if len(xs) == 1:
diff --git a/test/test_compile_step.py b/test/test_compile_step.py
new file mode 100644
index 0000000..14b8f80
--- /dev/null
+++ b/test/test_compile_step.py
@@ -0,0 +1,83 @@
+import inspect
+
+import pytest
+import torch
+
+import heavyball
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+EXTRA_KWARGS = {
+ "AdamC": {"max_lr": 0.01},
+}
+
+
+def _optimizer_params():
+ seen = set()
+ params = []
+ for name in heavyball.__all__:
+ if not hasattr(heavyball, name):
+ continue
+ obj = getattr(heavyball, name)
+ if not inspect.isclass(obj):
+ continue
+ if not issubclass(obj, torch.optim.Optimizer):
+ continue
+ ident = id(obj)
+ if ident in seen:
+ continue
+ seen.add(ident)
+ if name == "SplitOpt":
+ params.append(
+ pytest.param(name, obj, id=name, marks=pytest.mark.skip(reason="SplitOpt requires dict specs"))
+ )
+ continue
+ params.append(pytest.param(name, obj, id=name))
+ return params
+
+
+def _make_model():
+ return torch.nn.Sequential(
+ torch.nn.Linear(8, 16),
+ torch.nn.Tanh(),
+ torch.nn.Linear(16, 4),
+ ).to(DEVICE)
+
+
+def _run_steps(model, optimizer, n=5, seed=0xDEADBEEF):
+ torch.manual_seed(seed)
+ for _ in range(n):
+ def closure():
+ optimizer.zero_grad(set_to_none=True)
+ data = torch.randn(4, 8, device=DEVICE)
+ target = torch.randn(4, 4, device=DEVICE)
+ loss = torch.nn.functional.mse_loss(model(data), target)
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+
+@pytest.mark.parametrize("opt_name,opt_cls", _optimizer_params())
+def test_compile_step_matches_eager(opt_name, opt_cls):
+ """compile_step=True must produce identical parameters to compile_step=False."""
+ sig = inspect.signature(opt_cls.__init__)
+ if "compile_step" not in sig.parameters:
+ pytest.skip("optimizer does not accept compile_step")
+
+ kwargs = dict(EXTRA_KWARGS.get(opt_name, {}))
+
+ torch.manual_seed(0xDEADBEEF)
+ model_ref = _make_model()
+ model_test = _make_model()
+ model_test.load_state_dict(model_ref.state_dict())
+
+ opt_ref = opt_cls(model_ref.parameters(), compile_step=False, **kwargs)
+ opt_test = opt_cls(model_test.parameters(), compile_step=True, **kwargs)
+
+ _run_steps(model_ref, opt_ref)
+ _run_steps(model_test, opt_test)
+
+ for p_ref, p_test in zip(model_ref.parameters(), model_test.parameters()):
+ diff = (p_ref.data - p_test.data).abs().max().item()
+ assert diff < 1e-5, f"compile_step diverged: max_diff={diff}"
From 89187c8a3ebc917250113b78cc5e4b0d7ec4ca5b Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sun, 29 Mar 2026 17:54:59 +0000
Subject: [PATCH 06/24] compile step
---
README.md | 4 +-
examples/branched_optimizer.py | 2 +-
heavyball/__init__.py | 32 +--
heavyball/chainable.py | 331 +++++++++++++----------------
heavyball/utils.py | 101 ++-------
scripts/migrate_optimizer_state.py | 2 +-
test/test_chainable_cpu.py | 2 +-
test/test_compile_step.py | 2 +-
test/test_ecc.py | 4 +-
test/test_migrate_cli.py | 1 -
test/test_stochastic_updates.py | 58 -----
test/test_stochastic_utils_cpu.py | 16 +-
test/test_utils_cpu.py | 24 +--
test/test_utils_property.py | 3 +
test/utils.py | 1 -
15 files changed, 196 insertions(+), 387 deletions(-)
delete mode 100644 test/test_stochastic_updates.py
diff --git a/README.md b/README.md
index 56b5359..e8bd664 100644
--- a/README.md
+++ b/README.md
@@ -153,7 +153,7 @@ opt = SOAP(model.parameters(), lr=3e-3, orig_shapes=shapes)
## Building Custom Optimizers
Every built-in optimizer is a chain of `FunctionTransform`s, an API also available for building custom optimizers.
-`Branch` runs parallel transform paths with a merge function, which is useful for grafted optimizers or ensemble
+`Parallel` runs parallel transform paths with a merge function, which is useful for grafted optimizers or ensemble
updates.
```python
@@ -168,7 +168,7 @@ class GraftedAdam(C.BaseOpt):
weight_decay=0, warmup_steps=0, multi_tensor=True):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
warmup_steps=warmup_steps)
- branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft)
+ branch = C.Parallel(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft)
super().__init__(params, defaults, multi_tensor, fns=(branch,))
```
diff --git a/examples/branched_optimizer.py b/examples/branched_optimizer.py
index 001f107..73aa271 100644
--- a/examples/branched_optimizer.py
+++ b/examples/branched_optimizer.py
@@ -38,7 +38,7 @@ def __init__(
weight_decay=weight_decay,
warmup_steps=warmup_steps,
)
- branch = C.Branch(branches=[[C.scale_by_adam], [C.identity]], merge_fn=_graft)
+ branch = C.Parallel(branches=[[C.scale_by_adam], [C.identity]], merge_fn=_graft)
super().__init__(params, defaults, multi_tensor, fns=(branch,))
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index 476aac0..d760970 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -624,8 +624,6 @@ def __init__(
max_precond_dim: int = 2048, #
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
multi_tensor: bool = True,
@@ -639,7 +637,6 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
precond_grad_accum: bool = False,
compile_step: bool = C.use_default,
promote: bool = C.use_default,
@@ -683,8 +680,6 @@ def __init__(
max_precond_dim: int = 2048,
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
multi_tensor: bool = True,
@@ -698,7 +693,6 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
precond_grad_accum: bool = False,
momentum_decay: float = 4e-3,
decoupled_weight_decay: bool = False,
@@ -744,8 +738,6 @@ def __init__(
max_precond_dim: int = 2048,
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
multi_tensor: bool = True,
@@ -759,7 +751,6 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
precond_grad_accum: bool = False,
alpha: float = 2.0,
beta3_warmup: int | None = None,
@@ -853,8 +844,6 @@ def __init__(
max_precond_dim: int = 2048, #
merge_dims: bool = True,
precondition_1d: bool = False,
- normalize_grads: bool = False,
- correct_bias: bool = True,
warmup_steps: int = 0,
split: bool = False,
multi_tensor: bool = True,
@@ -868,7 +857,6 @@ def __init__(
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
storage_dtype: str = "float32",
- stochastic_schedule: bool = False,
compile_step: bool = C.use_default,
promote: bool = C.use_default,
ecc: str | None = None,
@@ -978,7 +966,6 @@ class PSGDKron(C.BaseOpt):
delayed: bool = False
cached: bool = False
exp_avg_input: bool = True
- quad: bool = False
def __init__(
self,
@@ -998,7 +985,6 @@ def __init__(
store_triu_as_line: bool = True,
multi_tensor: bool = True,
q_dtype="float32",
- stochastic_schedule: bool = False,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -1008,11 +994,9 @@ def __init__(
exp_avg_input: Optional[bool] = C.use_default,
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default, #
- adaptive: bool = False,
- ortho_method: Optional[str] = None, # If None, no orthogonalization
+ ortho_method: Optional[str] = None,
precond_grad_accum: bool = False,
- lower_bound_beta: float = 0.9, # 0.0 recovers pre-2.0.0 PSGD
- inverse_free: bool = C.use_default,
+ lower_bound_beta: float = 0.9,
dampening: float = 1e-9,
precond_update_power_iterations: int = 2,
# expert parameters
@@ -1034,11 +1018,6 @@ def __init__(
cached = C.default(cached, self.cached)
exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
update_clipping = C.default(update_clipping, utils.trust_region_clip_)
- inverse_free = C.default(inverse_free, self.quad)
- if inverse_free:
- raise ValueError(
- "inverse_free (i.e., PSGD-QUAD) is not supported at the moment. Consider using https://github.com/evanatyourservice/quad_torch"
- )
params, defaults = C._build_defaults(locals())
@@ -1087,7 +1066,6 @@ def __init__(
split: bool = False,
multi_tensor: bool = True,
q_dtype="float32",
- stochastic_schedule: bool = False,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -1116,7 +1094,6 @@ def __init__(
params, defaults = C._build_defaults(locals())
defaults["store_triu_as_line"] = False
- defaults["inverse_free"] = False
self.precond_schedule = C.default(
defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
@@ -1162,7 +1139,6 @@ def __init__(
warmup_steps: int = 0,
multi_tensor: bool = True, # True: global LRA across all params. False: independent per-param LRA.
q_dtype="float32",
- stochastic_schedule: bool = False,
storage_dtype: str = "float32",
mars: bool = False,
caution: bool = False,
@@ -1171,8 +1147,8 @@ def __init__(
exp_avg_input: Optional[bool] = C.use_default,
gradient_clipping: C.str_or_fn = C.use_default,
update_clipping: C.str_or_fn = C.use_default,
- eps: float = 1e-8, #
- precond_grad_accum: bool = False, # expert parameters
+ eps: float = 1e-8,
+ precond_grad_accum: bool = False,
precond_init_scale=None,
precond_init_scale_scale: float = 1,
precond_init_scale_power: Optional[float] = None,
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index bdad159..d9cd125 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -2,7 +2,6 @@
import copy
import functools
import math
-import random
from collections.abc import Iterable as _Iterable
from typing import Iterable, List, Literal, Optional, Union
@@ -25,12 +24,6 @@ def _key_in_state(state, key):
return True
-def _inplace_guard_(state, key, template_fn):
- key_not_in_state = not _key_in_state(state, key)
- if key_not_in_state:
- template_fn()
- return key_not_in_state
-
def _guard_in_state(state, key, template_fn):
if not _key_in_state(state, key):
@@ -45,7 +38,6 @@ def __init__(self, fn, names: list[str] | None = None):
self.fn = fn
self.fn_name = self.get_fn().__name__
self.transform_idx = None
- self.is_initialized = False
self.names = names
def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
@@ -55,7 +47,7 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
raise NotImplementedError
def __call__(self, state, group, update, grad, param, *args, **kwargs):
- states = [state(p) for p in param]
+ states = state if isinstance(state, list) else [state(p) for p in param]
skip_update = False
for st, a in zip(states, zip(update, grad, param, *args)):
if self.transform_idx not in st.get("is_initialized", set()):
@@ -77,9 +69,11 @@ def get_fn(self):
return self.fn.get_fn()
return self.fn
+ def _build_val_names(self):
+ self._val_names = {name: f"{self.fn_name}_{name}_{self.transform_idx}" for name in self.names}
+
def val_name(self, name):
- assert self.transform_idx is not None
- return f"{self.fn_name}_{name}_{self.transform_idx}"
+ return self._val_names[name]
def __repr__(self):
return f"{self.__class__.__name__}({self.fn}, transform_idx={self.transform_idx})"
@@ -155,7 +149,7 @@ def _sel(lst, idx):
continue
group["caution"] = caution
if fns is not None:
- u, skip = _inner_chain(state, group, _sel(update, idx), _sel(grad, idx), _sel(param, idx), *fns)
+ u, skip = _inner_chain(_sel(state, idx), group, _sel(update, idx), _sel(grad, idx), _sel(param, idx), *fns)
else:
u, skip = _sel(update, idx), False
results.append((u, skip, idx))
@@ -271,12 +265,11 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
if ecc is None:
return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
- states = [state(p) for p in param]
names = [self.val_name(n) for n in self.names]
- primary_vars = [[st[vn] for st in states] for vn in names]
+ primary_vars = [[st[vn] for st in state] for vn in names]
with contextlib.ExitStack() as stack:
for vn, plist in zip(names, primary_vars):
- corrs = [st[vn + "::ecc"] for st in states]
+ corrs = [st[vn + "::ecc"] for st in state]
stack.enter_context(ecc.attached(plist, corrs))
return self.fn(state, group, update, grad, param, *args, *primary_vars, **kwargs)
@@ -346,17 +339,6 @@ def __init__(self, fn, names, init_fn, skip_first: bool = True):
super().__init__(fn, names)
self.init_fn = init_fn
self.skip_first = skip_first
- self.named_to_anonymous = None
- self.anonymous_to_named = None
-
- def _map(self, state_fn, param, mapping):
- for p in param:
- state = state_fn(p)
- for name, mapped in mapping.items():
- if mapped in state:
- raise ValueError(f"Name {name} already mapped to {mapped}")
- if name in state:
- state[mapped] = state.pop(name)
def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
self.init_fn(state, group, update, grad, param, **kwargs)
@@ -370,12 +352,19 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
class NoState(FunctionTransform):
+ needs_init = False
+
def __call__(self, state, group, update, grad, param, *args, **kwargs):
return self.fn(group, update, grad, param, *args, **kwargs)
class NoStateNoMultiTensor(FunctionTransform):
def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ states = state if isinstance(state, list) else [state(p) for p in param]
+ for st in states:
+ if "is_initialized" not in st:
+ st["is_initialized"] = set()
+ st["is_initialized"].add(self.transform_idx)
updates = []
skip_update = False
for a in zip(update, grad, param, *args):
@@ -398,6 +387,8 @@ def _view_preserve_ecc(src, target):
class SqueezeGrad(FunctionTransform):
+ needs_init = False
+
def __call__(self, state, group, update, grad, param, *args, **kwargs):
original_shapes = [u.shape for u in update]
update = [u.squeeze() if u.numel() > 1 else u.view(-1) for u in update]
@@ -427,6 +418,32 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
return self.fn(state, group, update, grad, param, *args, **kwargs)
+class WarmupGuard(FunctionTransform):
+ def __init__(self, fn, warmup_fns):
+ super().__init__(fn, names=[])
+ self.warmup_fns = warmup_fns
+ self.warmup_key = None
+
+ def _build_val_names(self):
+ super()._build_val_names()
+ self.warmup_key = f"_warmup_{self.transform_idx}"
+
+ def __call__(self, state, group, update, grad, param, *args, **kwargs):
+ states = state if isinstance(state, list) else [state(p) for p in param]
+ warmup_step = min(st.get(self.warmup_key, 0) for st in states)
+ if warmup_step < len(self.warmup_fns):
+ fn = self.warmup_fns[warmup_step]
+ for st, a in zip(states, zip(update, grad, param, *args)):
+ fn(st, group, *a, **kwargs)
+ st[self.warmup_key] = st.get(self.warmup_key, 0) + 1
+ raise SkipUpdate from None
+ for st in states:
+ if "is_initialized" not in st:
+ st["is_initialized"] = set()
+ st["is_initialized"].add(self.transform_idx)
+ return self.fn(state, group, update, grad, param, *args, **kwargs)
+
+
needs_full_param = functools.partial(TagGuard, needs_full_param=True)
@@ -442,6 +459,10 @@ def general_guard(*names, init_fn, skip_first: bool = True):
return functools.partial(GeneralGuard, names=names, init_fn=init_fn, skip_first=skip_first)
+def warmup_guard(*warmup_fns):
+ return functools.partial(WarmupGuard, warmup_fns=list(warmup_fns))
+
+
def no_state(fn):
return NoState(fn)
@@ -700,7 +721,7 @@ def orthogonalize_grad_to_param(group, update, grad, param):
@no_state
def update_by_schedule_free(group, update, grad, param, z):
# Compute weight_sum once per step, not per param in no-multi_tensor mode.
- if group.get("_sf_step") != group["step"]:
+ if group.get("_sf_step") is not group["step"]:
weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
group["weight_sum"] = group.get("weight_sum", 0) + weight
group["_sf_step"] = group["step"]
@@ -738,28 +759,21 @@ def update_by_msam(group, update, grad, param, z, exp_avg):
raise SkipUpdate from None
+def _adopt_warmup_1(state, group, update, grad, param, exp_avg, exp_avg_sq):
+ utils.scale_by_exp_avg_sq_([exp_avg_sq], [update], 0, group["eps"])
+
+
+def _adopt_warmup_2(state, group, update, grad, param, exp_avg, exp_avg_sq):
+ u = utils.promote(update)
+ easq = utils.promote(exp_avg_sq)
+ utils.copy_stochastic_(exp_avg, u / easq.sqrt().clamp_(min=group["eps"]))
+ utils.scale_by_exp_avg_sq_([exp_avg_sq], [update], utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"])
+
+
@zero_guard("exp_avg", "exp_avg_sq")
+@warmup_guard(_adopt_warmup_1, _adopt_warmup_2)
@no_state
def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
- if group["step"] == 1:
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
- raise SkipUpdate from None
-
- if group["step"] == 2:
- update = utils.promote(update)
- easq = utils.promote(exp_avg_sq)
- [
- utils.copy_stochastic_(ea, u / easq_.sqrt().clamp_(min=group["eps"]))
- for ea, u, easq_ in zip(exp_avg, update, easq)
- ]
- utils.scale_by_exp_avg_sq_(
- exp_avg_sq,
- update,
- utils.beta_debias(utils.get_beta2(group), group["step"]),
- group["eps"],
- )
- raise SkipUpdate from None
-
utils.fused_adopt_(
param,
update,
@@ -777,14 +791,15 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
raise SkipUpdate from None
+def _suds_warmup_1(state, group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
+ utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
+
+
@needs_full_param
@zero_guard("exp_avg", "exp_avg_sq", "fisher_approx")
+@warmup_guard(_suds_warmup_1)
@no_state_no_multi_tensor
def scale_by_suds(group, update, grad, param, exp_avg, exp_avg_sq, fisher_approx):
- if group["step"] == 1:
- utils.copy_stochastic_(fisher_approx, update / update.norm().clamp(min=1e-8))
- raise SkipUpdate from None
-
precond_update, w = utils.eigvecs_product_rank1(update.flatten(), fisher_approx.flatten().to(update.dtype))
precond_update = utils.adam_(
exp_avg,
@@ -817,27 +832,9 @@ def scale_by_unscaled_adam(group, update, grad, param, exp_avg, exp_avg_sq):
@zero_guard("exp_avg", "exp_avg_sq")
+@warmup_guard(_adopt_warmup_1, _adopt_warmup_2)
@no_state
def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
- if group["step"] == 1:
- utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group["eps"])
- raise SkipUpdate from None
-
- if group["step"] == 2:
- update = utils.promote(update)
- easq = utils.promote(exp_avg_sq)
- [
- utils.copy_stochastic_(ea, u / easq_.sqrt().clamp_(min=group["eps"]))
- for ea, u, easq_ in zip(exp_avg, update, easq)
- ]
- utils.scale_by_exp_avg_sq_(
- exp_avg_sq,
- update,
- utils.beta_debias(utils.get_beta2(group), group["step"]),
- group["eps"],
- )
- raise SkipUpdate from None
-
return utils.adopt(
update,
exp_avg_sq,
@@ -863,9 +860,7 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
)
state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
- state["step"] = torch.zeros((), device=param.device, dtype=torch.float64) # torch casts int to float in ckpt load
- if group["adaptive"]:
- state["velocity"] = [torch.zeros((), device=q.device, dtype=q.dtype) for q in Q]
+ state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
if not cached:
return
@@ -907,21 +902,6 @@ def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob
)
-def precond_schedule(group, prob: Union[callable, float, None] = None, name: str = "cumulative_prob"):
- step = group["step"]
- if "precondition_frequency" in group:
- return step > 0 and step % group["precondition_frequency"] == 0
- if isinstance(step, torch.Tensor):
- utils.warn_once("Preconditioner schedule is not supported with torch.Tensor step.")
- rng = random.Random(0x172381)
- else:
- rng = random.Random(0x172381 ^ step)
- if "precond_scheduler" in group:
- return utils.precond_schedule(step, group["precond_scheduler"], rng)
- if prob is not None:
- return utils.psgd_should_update(group, prob, rng, name=name)
- raise ValueError("No preconditioner update schedule specified.")
-
@needs_full_param
@no_state_no_multi_tensor
@@ -948,7 +928,7 @@ def nesterov_ema(group, updates, grads, params, momentum): # equivalent to Grok
def _store_init_norm(state, group, update, grad, param):
- state["init_norm"] = param.float().norm()
+ state["init_norm"] = param.to(_storage_dtype(group)).norm()
@needs_full_param
@@ -960,7 +940,7 @@ def update_by_hyperball(group, update, grad, param, init_norm):
def _store_std(state, group, update, grad, param):
- state["init_std"] = torch.std(param)
+ state["init_std"] = torch.std(param.to(_storage_dtype(group)))
@needs_full_param
@@ -1145,7 +1125,6 @@ def _update_psgd_precond(
param,
grad,
Q,
- velocity,
running_lower_bound,
step,
prob: Optional[callable] = None,
@@ -1160,8 +1139,6 @@ def _update_psgd_precond(
vector, hessian_vector = param.vector, param.hessian_vector
del param.vector
del param.hessian_vector
- elif group["inverse_free"]:
- vector, hessian_vector = None, grad
else:
vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
@@ -1170,7 +1147,6 @@ def _update_psgd_precond(
group["precond_lr"],
Q,
group["store_triu_as_line"],
- velocity,
utils.get_beta2(group),
group["ortho_method"],
vector,
@@ -1189,9 +1165,9 @@ def _update_psgd_precond(
if precond is not None:
return precond
if not should_use_cache or not cached:
- return None # caching adds extra ops and is not worth the overhead when we precondition at every step
+ return None
- Q_resolved = utils.line_to_triu(Q, group["inverse_free"]) if group["store_triu_as_line"] else Q
+ Q_resolved = utils.line_to_triu(Q) if group["store_triu_as_line"] else Q
for i, (c_, q_) in enumerate(zip(Q_cache, Q_resolved)):
if c_ is None:
c_ = (
@@ -1263,7 +1239,7 @@ def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
else:
out = utils.psgd_precond_grad(
- preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
+ preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs
)
group["caution"] = False # we already cautioned here - shouldn't do it again
return out
@@ -1282,7 +1258,7 @@ def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
else:
utils.fused_psgd_precond_grad(
- preconds=Q, store_triu_as_line=group["store_triu_as_line"], symmetric_output=group["inverse_free"], **kwargs
+ preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs
)
@@ -1351,7 +1327,7 @@ def update_by_delayed_psgd_lra(group, update, grad, param, update_to_precond, U,
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
@no_state_no_multi_tensor
def scale_by_psgd(
group,
@@ -1361,20 +1337,19 @@ def scale_by_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
- _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
@no_state_no_multi_tensor
def scale_by_delayed_psgd(
group,
@@ -1384,26 +1359,22 @@ def scale_by_delayed_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
- if group.get("inverse_free", False):
- precond = None
- else:
- precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
- new = _update_psgd_precond(
- cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob
+ precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
+ _update_psgd_precond(
+ cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob
)
- return new if precond is None else precond
+ return precond
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
@no_state_no_multi_tensor
def update_by_psgd(
group,
@@ -1413,13 +1384,12 @@ def update_by_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
- _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
_fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
raise SkipUpdate from None
@@ -1439,7 +1409,7 @@ def global_clip(group, update, grad, param, clip_fn: Optional[callable] = None):
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
-@general_guard("Q", "Q_cache", "velocity", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
+@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_kron, skip_first=False)
@no_state_no_multi_tensor
def update_by_delayed_psgd(
group,
@@ -1449,14 +1419,13 @@ def update_by_delayed_psgd(
update_to_precond,
Q,
Q_cache,
- velocity: Optional[List[Tensor]],
running_lower_bound: List[Tensor],
step: Tensor,
cached: bool = False,
prob: Optional[callable] = None,
):
_fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
- _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, velocity, running_lower_bound, step, prob)
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
raise SkipUpdate from None
@@ -1680,7 +1649,7 @@ def _inner_chain(state, group, update, grad, param, *fns):
return update, skip_update
-def chain(state: Union[callable, dict], group, grad, param, *fns):
+def chain(state: list, group, grad, param, *fns):
update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
ecc = ECCConfig.from_group(group, key="param_ecc")
@@ -1691,8 +1660,7 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
utils.update_param_(param, update, group["lr"], group["weight_decay"], caution=group["caution"], grad=grad)
return
- states = [state(p) for p in param]
- corrs = [st["param::ecc"] for st in states]
+ corrs = [st["param::ecc"] for st in state]
with ecc.attached(param, corrs):
update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
if not skip_update and update is not None:
@@ -1731,6 +1699,7 @@ def set_indices(fns: Iterable[callable], retain: bool = True, offset: int = 0):
for ft in _walk_fns(new_fns):
if not retain or ft.transform_idx is None:
ft.transform_idx, offset = offset, offset + 1
+ ft._build_val_names()
return new_fns
@@ -1754,8 +1723,9 @@ def __init__(self, params, defaults, multi_tensor: bool, *fns):
base.update({k: v for k, v in defaults.items() if v is not use_default})
super().__init__(params, base, multi_tensor)
self.fns = fns
+ self._eager_chain = self._run_chain
if self.compile_step:
- self._chain = torch.compile(self._chain, fullgraph=True)
+ self._run_chain = torch.compile(self._run_chain, fullgraph=True)
self.register_load_state_dict_post_hook(ChainOpt._restore_ecc_dtypes)
self._init_param_ecc()
@@ -1846,10 +1816,20 @@ def fns(self, value):
self._fns = value
self._set_indices(retain=True)
self._needs_gather = any(getattr(ft, "needs_full_param", False) for ft in _walk_fns(self._fns))
+ self._transform_ids = frozenset(
+ ft.transform_idx for ft in _walk_fns(self._fns)
+ if ft.transform_idx is not None and getattr(ft, "needs_init", True)
+ )
def _set_indices(self, retain=True):
self._fns = set_indices(self.fns, retain)
+ def _find_val_name(self, name):
+ for ft in _walk_fns(self._fns):
+ if name in ft._val_names:
+ return ft._val_names[name]
+ raise KeyError(f"No transform stores '{name}'")
+
def _step(self, group):
if "base_lr" not in group:
group["base_lr"] = group["lr"]
@@ -1881,65 +1861,51 @@ def _step_inner(self, group):
for param in p:
state = self.state_(param)
- if "step" in state:
- step = state["step"]
- elif self.compile_step:
- step = utils.scalar_guard(0, param)
- else:
- step = 0
+ step = state.get("step", 0)
+ if not isinstance(step, torch.Tensor):
+ step = torch.tensor(step, dtype=torch.int64, device=param.device)
+ state["step"] = step
break
group["step"] = state["step"] = step = step + 1
group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
- if isinstance(step, torch.Tensor):
- _orig_floats = {}
- for k, v in group.items():
- if isinstance(v, float):
- _orig_floats[k] = v
- group[k] = torch.tensor(v, dtype=torch.float64, device=step.device)
- elif isinstance(v, tuple) and any(isinstance(x, float) for x in v):
- _orig_floats[k] = v
- group[k] = tuple(
- torch.tensor(x, dtype=torch.float64, device=step.device) if isinstance(x, float) else x
- for x in v
- )
-
if not group["multi_tensor"] or len(p) == 1:
for param, grad in zip(p, g):
self._chain(group, [grad], [param], caution)
else:
self._chain(group, g, p, caution)
- if isinstance(step, torch.Tensor):
- group.update(_orig_floats)
-
group["caution"] = caution
group["lr"] = group["prev_lr"]
group["step"] = None
- def _chain(self, group, g, p, caution):
- chain(self.state_, group, g, p, *self.fns)
+ def _run_chain(self, state, group, g, p, caution):
+ chain(state, group, g, p, *self.fns)
group["caution"] = caution
+ def _needs_init(self, state):
+ ids = self._transform_ids
+ return ids and any(not ids.issubset(st.get("is_initialized", set())) for st in state)
+
+ def _needs_eager(self, group, state):
+ if self._needs_init(state):
+ return True
+ if group.get("is_preconditioning", False):
+ return True
+ if group.get("ecc") or group.get("param_ecc"):
+ return True
+ return False
-str_or_fn = Union[str, callable, None, Literal[use_default]]
+ def _chain(self, group, g, p, caution):
+ state = [self.state_(pi) for pi in p]
+ fn = self._run_chain
+ if self.compile_step and self._needs_eager(group, state):
+ fn = self._eager_chain
+ fn(state, group, g, p, caution)
-def _get_clip_fn(name: str_or_fn, default_val: str_or_fn):
- name = default(name, default_val)
- if callable(name):
- return name
- elif name not in (
- "l2_clip_",
- "rmsnorm_clip_",
- "trust_region_clip_",
- "a_law_compress",
- "mu_law_compress",
- "softsign_compress",
- ):
- raise ValueError(f"Clipping function {name} not found")
- return getattr(utils, name)
+str_or_fn = Union[str, callable, None, Literal[use_default]]
def default(a, b):
@@ -1978,23 +1944,17 @@ class BaseOpt(ChainOpt):
Base Optimizer
compile_step: bool = False
- Whether to change some internals to try to make the optimizer compilable
- This does not compile the step by itself and breaks some optimizers loudly (e.g. SOAP)
+ Whether to torch.compile the optimizer step (fullgraph=True).
+ Initialization runs eagerly on the first step; subsequent steps are compiled.
promote: bool = False
- Whether to promote the gradients to fp32 before applying the optimizer
- Improves update quality for low-precision parameters, but increases costs
- Compiling the optimizer step would reduce memory and compute. Alternatively, `multi_tensor=False` decreases memory at the cost of runtime
+ Whether to promote the gradients to fp32 before applying the optimizer.
gradient_clipping: str_or_fn = None
- The function to use for clipping the incoming gradients, before any other transformations.
- This is syntactic sugar, equivalent to manually passing the function as the first element of the optimizer chain.
+ Clipping function applied to incoming gradients before any other transforms.
update_clipping: str_or_fn = None
- The function to use for clipping the outgoing updates before applying them, after all other transformations.
- This will turn off fused updates.
- This is syntactic sugar, equivalent to manually passing the function as the last element of the optimizer chain.
-
+ Clipping function applied to outgoing updates. Disables fused updates.
"""
gradient_clipping: str_or_fn = None
@@ -2059,28 +2019,29 @@ def __init__(
class ScheduleFree(BaseOpt):
def eval(self):
+ z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
+ group["train_mode"] = train_mode = not group.get("train_mode", True)
beta1 = utils.get_beta1(group)
if beta1 > 0 and not train_mode:
for p in group["params"]:
state = self.state_(p)
- if "z" in state:
- # Set p.data to x
- z = utils.promote(state["z"])
+ if z_key in state:
+ z = utils.promote(state[z_key])
p32 = utils.promote(p.data)
p32.lerp_(end=z, weight=1 - 1 / beta1)
utils.copy_stochastic_(p.data, p32)
def train(self):
+ z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
+ group["train_mode"] = train_mode = not group.get("train_mode", False)
beta1 = utils.get_beta1(group)
if beta1 > 0 and train_mode:
for p in group["params"]:
state = self.state_(p)
- if "z" in state:
- z = utils.promote(state["z"])
+ if z_key in state:
+ z = utils.promote(state[z_key])
p32 = utils.promote(p.data)
p32.lerp_(end=z, weight=1 - beta1)
utils.copy_stochastic_(p.data, p32)
@@ -2088,23 +2049,25 @@ def train(self):
class MSAM(BaseOpt):
def eval(self):
+ z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
+ group["train_mode"] = train_mode = not group.get("train_mode", True)
if not train_mode:
for p in group["params"]:
state = self.state_(p)
- if "z" in state:
+ if z_key in state:
p_copy = p.data.clone()
- utils.copy_stochastic_(p.data, state["z"])
- utils.copy_stochastic_(state["z"], p_copy)
+ utils.copy_stochastic_(p.data, state[z_key])
+ utils.copy_stochastic_(state[z_key], p_copy)
def train(self):
+ z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode")
+ group["train_mode"] = train_mode = not group.get("train_mode", False)
if train_mode:
for p in group["params"]:
state = self.state_(p)
- if "z" in state:
+ if z_key in state:
p_copy = p.data.clone()
- utils.copy_stochastic_(p.data, state["z"])
- utils.copy_stochastic_(state["z"], p_copy)
+ utils.copy_stochastic_(p.data, state[z_key])
+ utils.copy_stochastic_(state[z_key], p_copy)
diff --git a/heavyball/utils.py b/heavyball/utils.py
index d0154b5..50642ca 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1293,7 +1293,6 @@ class StatefulOptimizer(torch.optim.Optimizer):
compile_step: bool = False
hessian_approx: bool = False
precond_schedule: Union[Callable, float, None] = None
- stochastic_schedule: bool | Literal[use_default] = use_default
finite_differences: bool = False
fallback_to_finite_differences: bool = True
_fallback_enabled: bool = False
@@ -1318,17 +1317,7 @@ def __init__(self, params, defaults, multi_tensor: bool = True, use_ema: bool =
self.mapping = {}
self.mapping_inverse = {}
- if self.stochastic_schedule is use_default:
- stochastic_schedule = None
- for group in self.param_groups:
- new = group.get("stochastic_schedule", stochastic_schedule)
- if stochastic_schedule is not None and new != stochastic_schedule:
- raise ValueError("All parameter groups must have the same stochastic_schedule.")
- stochastic_schedule = new
- self.stochastic_schedule = stochastic_schedule
-
- self.inner_group = {"stochastic_schedule": self.stochastic_schedule}
- self.precond_rng = random.Random(0x12312)
+ self.inner_group = {}
self._is_preconditioning = None
if self.hessian_approx and self.compile_step:
@@ -1340,22 +1329,24 @@ def __init__(self, params, defaults, multi_tensor: bool = True, use_ema: bool =
def _store_stats(self, state_dict: dict[str, any]):
state_dict["heavyball"] = {
"inner_group": self.inner_group,
- "precond_rng": pickle.dumps(self.precond_rng),
"use_ema": self.use_ema,
"ema_decay": self.ema_decay,
"compile_step": self.compile_step,
"hessian_approx": self.hessian_approx,
"precond_schedule": pickle.dumps(self.precond_schedule),
- "stochastic_schedule": self.stochastic_schedule,
"fallback_to_finite_differences": self.fallback_to_finite_differences,
"_fallback_enabled": self._fallback_enabled,
"hvp_interval": self.hvp_interval,
}
+ _REMOVED_STATS = frozenset({"stochastic_schedule", "precond_rng"})
+
def _load_stats(self, state_dict):
sd = state_dict.pop("heavyball", {})
for k, v in sd.items():
- if k in ("precond_rng", "precond_schedule"):
+ if k in self._REMOVED_STATS:
+ continue
+ if k in ("precond_schedule",):
v = pickle.loads(v)
setattr(self, k, v)
@@ -1603,7 +1594,7 @@ def step(self, closure: Optional[Callable] = None):
if self.precond_schedule is None:
self._is_preconditioning = False
else:
- self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule, self.precond_rng)
+ self._is_preconditioning = psgd_should_update(self.inner_group, self.precond_schedule)
loss = self._handle_closure(closure)
if self.use_ema:
@@ -2861,10 +2852,10 @@ def _approx():
@decorator_knowngood
-def _balance_to_triu(Q: "TriuOrLine", symmetric_output: bool = False):
+def _balance_to_triu(Q: "TriuOrLine"):
if isinstance(Q[0], tuple):
psgd_balance_Q([o[1] for o in Q])
- return line_to_triu(Q, symmetric_output)
+ return line_to_triu(Q)
psgd_balance_Q(Q)
return Q
@@ -2980,7 +2971,6 @@ def psgd_update_precond(
precond_lr: float,
oq: "TriuOrLine",
store_triu_as_line: bool,
- velocity: Optional[List[Tensor]],
beta2: float,
ortho_method: Optional[str],
V: Tensor,
@@ -3176,64 +3166,9 @@ def _psgd_precond_update_(
@decorator_knowngood
-def _psgd_quad_preconditioner_grad(GG: List[Tensor], Q: List[Tensor], numel: int):
- """
- I: Identity
- U: Update / gg / target
- Q: q, preconditioner
- scale: scalar scale
- ---
- U = T * scale - I
- F = I - U # = 2I - U * scale
- O = F @ Q @ F - Q
- """
- out = []
- for gg, q in zip(GG, Q):
- if gg.ndim < 2:
- scale = max(1, gg.numel()) / numel
- target = promote(gg)
- update = target * scale - 1
- out.append(q - (1 - update) * q * (1 - update))
- else:
- scale = gg.size(0) / numel
- gg = 2 * torch.eye(gg.size(0), device=gg.device, dtype=gg.dtype) - gg * scale
- update = q - casted_einsum("ab,cd,bc", gg, gg, q)
- out.append(update + update.T) # make matrix symmetric
- return out
@decorator
-def inverse_free_psgd_update_precond(
- G: Tensor,
- precond_lr: float,
- oq: List[Tensor],
- store_triu_as_line: bool,
- velocity: Optional[List[Tensor]],
- beta2: float,
- ortho_method: Optional[str],
- V: None,
- running_lower_bound: List[Tensor],
- lower_bount_beta: float,
- power_iter: int,
-) -> Tensor:
- """Update Kronecker product preconditioner Q with pair (V, G)."""
- assert V is None
- assert ortho_method is None
- assert velocity is None
- del V, ortho_method, velocity
-
- Q = _balance_to_triu(oq, True)
- precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
- exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
-
- G = psgd_precond_grad(G, Q)
- terms = [compiled_einsum(exprG, G, G) for exprG in exprGs]
- matmuled = _psgd_quad_preconditioner_grad(terms, Q, G.numel())
- _psgd_precond_update_(
- matmuled, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
- )
- return G
-
@decorator_knowngood
def _clip(x, norm, clip_at, eps=1e-8):
@@ -3495,15 +3430,13 @@ def triu_to_line(Q_list: List[Tensor]):
@decorator_knowngood
-def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]], symmetric_output: bool = False):
+def line_to_triu(Q_list: List[Tuple[Optional[List[int]], Tensor]]):
new = []
for shape, q in Q_list:
if shape is not None:
x, y = torch.triu_indices(*shape, device=q.device)
q_mat = torch.zeros(shape, device=q.device, dtype=q.dtype)
q_mat[x, y] = q
- if symmetric_output:
- q_mat[y, x] = q
q = q_mat
new.append(q)
return new
@@ -3518,14 +3451,10 @@ def warn_once(msg):
_warned.add(msg)
-def psgd_should_update(
- group, prob: Union[float, callable], rng: Optional[random.Random] = None, name: str = "cumulative_prob"
-):
+def psgd_should_update(group, prob: Union[float, callable], name: str = "cumulative_prob"):
group[f"{name}_prob_step"] = group.get(f"{name}_prob_step", 0) + 1
if not isinstance(prob, float):
prob = prob(group[f"{name}_prob_step"])
- if group["stochastic_schedule"]:
- return rng.random() < prob
cumulative_prob = group.get(name, 0)
group[name] = cumulative_prob + prob
return int(group[name]) > int(cumulative_prob)
@@ -3587,12 +3516,11 @@ def psgd_precond_grad(
caution: bool = False,
grad: Optional[Tensor] = None,
store_triu_as_line: bool = False,
- symmetric_output: bool = False,
):
if caution:
ea = _compilable_cautioning(grad, ea)
if store_triu_as_line:
- preconds = line_to_triu(preconds, symmetric_output)
+ preconds = line_to_triu(preconds)
args = [promote(q) for q in preconds]
expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
return compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
@@ -3608,7 +3536,6 @@ def _compilable_fused_psgd_precond_grad(
caution,
preconds: TriuOrLine,
store_triu_as_line: bool = False,
- symmetric_output: bool = False,
):
precond = psgd_precond_grad(
ea,
@@ -3616,7 +3543,6 @@ def _compilable_fused_psgd_precond_grad(
caution=caution,
grad=grad,
store_triu_as_line=store_triu_as_line,
- symmetric_output=symmetric_output,
)
update_param_(param, precond, lr, decay, caution=False, grad=grad)
@@ -3630,11 +3556,10 @@ def fused_psgd_precond_grad(
caution,
preconds: TriuOrLine,
store_triu_as_line: bool = False,
- symmetric_output: bool = False,
):
lr, decay = scalar_guard(lr, decay, param[0])
_compilable_fused_psgd_precond_grad(
- ea, param, lr, grad, decay, caution, preconds, store_triu_as_line, symmetric_output
+ ea, param, lr, grad, decay, caution, preconds, store_triu_as_line
)
diff --git a/scripts/migrate_optimizer_state.py b/scripts/migrate_optimizer_state.py
index 1eef0aa..8931cc7 100644
--- a/scripts/migrate_optimizer_state.py
+++ b/scripts/migrate_optimizer_state.py
@@ -104,7 +104,7 @@ def walk(queue: Iterable[Any]):
stack.append(current.fn)
elif isinstance(current, functools.partial): # type: ignore[name-defined]
stack.append(current.func)
- elif isinstance(current, C.Branch):
+ elif isinstance(current, C.Parallel):
for branch in current.branches:
stack.extend(branch)
elif isinstance(current, (list, tuple)):
diff --git a/test/test_chainable_cpu.py b/test/test_chainable_cpu.py
index 6612028..64e24c6 100644
--- a/test/test_chainable_cpu.py
+++ b/test/test_chainable_cpu.py
@@ -36,7 +36,7 @@ def negate(_, __, update, ___, ____):
def merge_fn(outputs):
return [sum(vals) / len(vals) for vals in zip(*outputs)]
- branch = C.Branch([[double], [negate]], merge_fn)
+ branch = C.Parallel([[double], [negate]], merge_fn)
update = [torch.ones(2)]
grad = [torch.ones(2)]
diff --git a/test/test_compile_step.py b/test/test_compile_step.py
index 14b8f80..de210e7 100644
--- a/test/test_compile_step.py
+++ b/test/test_compile_step.py
@@ -80,4 +80,4 @@ def test_compile_step_matches_eager(opt_name, opt_cls):
for p_ref, p_test in zip(model_ref.parameters(), model_test.parameters()):
diff = (p_ref.data - p_test.data).abs().max().item()
- assert diff < 1e-5, f"compile_step diverged: max_diff={diff}"
+ assert diff < 1e-4, f"compile_step diverged: max_diff={diff}"
diff --git a/test/test_ecc.py b/test/test_ecc.py
index 1dd7364..163d4f9 100644
--- a/test/test_ecc.py
+++ b/test/test_ecc.py
@@ -130,7 +130,7 @@ def test_param_ecc_convergence(combined):
kw = {"param_ecc": "bf16+8"}
if combined:
kw["ecc"] = "bf16+8"
- m1, o1 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2)
+ m1, o1 = _model_opt(heavyball.SFAdamW, 16, 8, 1e-2, **kw)
losses_ecc = _train(m1, o1, data, target, 200)
p = list(m1.parameters())[0]
assert p.dtype == torch.bfloat16
@@ -302,7 +302,7 @@ def _measure_peak(cls, n, lr, ecc=None, param_ecc=None, steps=3):
kw["ecc"] = ecc
if param_ecc:
kw["param_ecc"] = param_ecc
- opt = cls([p], lr=lr)
+ opt = cls([p], lr=lr, **kw)
for _ in range(steps):
p.grad = torch.randn_like(p)
diff --git a/test/test_migrate_cli.py b/test/test_migrate_cli.py
index 7507289..253c6d8 100644
--- a/test/test_migrate_cli.py
+++ b/test/test_migrate_cli.py
@@ -169,7 +169,6 @@ def test_cli_migrates_legacy_checkpoint(runner, tmp_path):
assert view_state["is_initialized"] == [0]
assert "heavyball" in migrated_state
- assert migrated_state["heavyball"]["inner_group"]["stochastic_schedule"] is None
assert "Migrated checkpoint written to" in result.stdout
finally:
for name in list(sys.modules):
diff --git a/test/test_stochastic_updates.py b/test/test_stochastic_updates.py
deleted file mode 100644
index bd3f409..0000000
--- a/test/test_stochastic_updates.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import pytest
-import torch
-from torch import nn
-from utils import REPRESENTATIVE_OPTS
-
-import heavyball
-from heavyball.utils import clean, set_torch
-
-_SAVED_COMPILE_MODE = heavyball.utils.compile_mode
-heavyball.utils.compile_mode = "default"
-
-
-@pytest.fixture(autouse=True)
-def _isolate_compile_mode():
- heavyball.utils.compile_mode = "default"
- yield
- heavyball.utils.compile_mode = _SAVED_COMPILE_MODE
-
-
-PSGD_OPTS = [o for o in REPRESENTATIVE_OPTS if "PSGD" in o]
-
-
-@pytest.mark.parametrize("opt", PSGD_OPTS)
-def test_foreach(opt, size: int = 128, depth: int = 1, iterations: int = 512, outer_iterations: int = 2):
- set_torch()
-
- opt = getattr(heavyball, opt)
-
- losses = []
-
- for stochastic in [False, True]:
- print("stochastic", stochastic)
- torch.manual_seed(0x2131290)
- losses.append([])
-
- for i in range(outer_iterations):
- model = nn.Sequential(*[nn.Linear(size, size, bias=False) for _ in range(depth)]).cuda()
- o = opt(
- model.parameters(),
- lr=1e-3,
- stochastic_schedule=stochastic,
- preconditioner_update_probability=lambda step: 0.1,
- )
-
- for _ in range(iterations):
- loss = model(torch.randn((128, size), device="cuda")).square().mean()
- loss.backward()
- o.step()
- o.zero_grad()
- losses[-1].append(loss.detach())
-
- del model, o
- clean()
-
- stochastic = sum([l.item() for l in losses[1]])
- deterministic = sum([l.item() for l in losses[0]])
- print(f"{deterministic=}, {stochastic=}")
- assert not torch.isclose(torch.tensor(deterministic), torch.tensor(stochastic), rtol=0.01)
diff --git a/test/test_stochastic_utils_cpu.py b/test/test_stochastic_utils_cpu.py
index 0b16f2d..693533d 100644
--- a/test/test_stochastic_utils_cpu.py
+++ b/test/test_stochastic_utils_cpu.py
@@ -1,8 +1,6 @@
-import os
-
-os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
-
+import pytest
import torch
+from torch._dynamo import config
import heavyball
from heavyball.utils import (
@@ -13,7 +11,15 @@
stochastic_divide_with_eps_,
)
-heavyball.utils.atan2_scale = 1024.0
+config.cache_size_limit = 128
+
+
+@pytest.fixture(autouse=True)
+def _restore_atan2_scale():
+ orig = heavyball.utils.atan2_scale
+ heavyball.utils.atan2_scale = 1024.0
+ yield
+ heavyball.utils.atan2_scale = orig
def _average_stochastic_round(source: torch.Tensor, trials: int = 512) -> torch.Tensor:
diff --git a/test/test_utils_cpu.py b/test/test_utils_cpu.py
index 54786a4..621056c 100644
--- a/test/test_utils_cpu.py
+++ b/test/test_utils_cpu.py
@@ -40,6 +40,9 @@
# Ensure Torch dynamo stays disabled on CI runners without GPU support.
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
+from torch._dynamo import config
+
+config.cache_size_limit = 128
_SAVED_COMPILE_MODE = heavyball.utils.compile_mode
heavyball.utils.compile_mode = None
@@ -173,9 +176,9 @@ def test_triu_line_roundtrip_on_cpu():
torch.arange(9, dtype=torch.float32).reshape(3, 3),
]
packed = triu_to_line(tensors)
- restored = line_to_triu(packed, symmetric_output=True)
+ restored = line_to_triu(packed)
for original, rebuilt in zip(tensors, restored, strict=True):
- assert torch.allclose(rebuilt, torch.triu(original) + torch.triu(original, diagonal=1).T)
+ assert torch.allclose(rebuilt, torch.triu(original))
def test_warn_once_only_emits_single_warning(monkeypatch):
@@ -192,7 +195,7 @@ def test_warn_once_only_emits_single_warning(monkeypatch):
def test_psgd_should_update_accumulates_probability():
- group = {"stochastic_schedule": False}
+ group = {}
outcomes = [psgd_should_update(group, 0.4) for _ in range(4)]
assert outcomes[:2] == [False, False]
assert outcomes[2] is True
@@ -200,15 +203,6 @@ def test_psgd_should_update_accumulates_probability():
assert group["cumulative_prob_prob_step"] == 4
-def test_psgd_should_update_stochastic_schedule_uses_rng():
- rng = random.Random(123)
- group = {"stochastic_schedule": True}
- calls = [psgd_should_update(group, 0.5, rng=rng) for _ in range(5)]
- rng = random.Random(123)
- expected = [rng.random() < 0.5 for _ in range(5)]
- assert calls == expected
-
-
def test_stochastic_math_helpers_match_expected_results(n=1024):
torch.manual_seed(0x172893)
a = torch.arange(n).float()
@@ -226,7 +220,8 @@ def test_stochastic_math_helpers_match_expected_results(n=1024):
stochastic_add_divide_(c, b, alpha=1.0, divisor=2.0)
assert torch.allclose(c.float(), (a + b * 1) / 2)
- orig = heavyball.utils.default_division_backend
+ orig_backend = heavyball.utils.default_division_backend
+ orig_scale = heavyball.utils.atan2_scale
try:
heavyball.utils.atan2_scale = 1024
for backend in heavyball.utils.DivisionBackend:
@@ -235,7 +230,8 @@ def test_stochastic_math_helpers_match_expected_results(n=1024):
stochastic_divide_with_eps_(c, b)
assert torch.allclose(c.float(), a / b), f"Backend {backend} failed"
finally:
- heavyball.utils.default_division_backend = orig
+ heavyball.utils.default_division_backend = orig_backend
+ heavyball.utils.atan2_scale = orig_scale
def test_stochastic_math_accuracy():
diff --git a/test/test_utils_property.py b/test/test_utils_property.py
index d0c25cb..ec6c3e5 100644
--- a/test/test_utils_property.py
+++ b/test/test_utils_property.py
@@ -22,6 +22,9 @@
# Ensure torch.compile stays disabled on CPU-only CI runners.
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
+from torch._dynamo import config
+
+config.cache_size_limit = 128
heavyball.utils.compile_mode = None
diff --git a/test/utils.py b/test/utils.py
index 906526a..163445f 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -15,7 +15,6 @@
_SKIP_GET_OPTIM = {
"AdEMAMix",
"SOAPAdEMAMix",
- "SOAPAdEMAMix",
"SplitOpt",
"SAMWrapper",
}
From bf882d81d9ee411a93a93720ecfa72bd810fdffd Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sun, 29 Mar 2026 20:26:49 +0000
Subject: [PATCH 07/24] cleanup
---
benchmarks/bench_singular_values.py | 2 +-
docs/benchmark.md | 4 +-
examples/autoencoder.py | 1 -
heavyball/__init__.py | 230 ++--
heavyball/chainable.py | 7 +-
heavyball/utils.py | 4 +-
interactive/playground.py | 1596 ---------------------------
interactive/static/init-globals.js | 2 -
interactive/static/node-editor.js | 570 ----------
pyproject.toml | 2 +-
scripts/migrate_optimizer_state.py | 124 ++-
test/test_compile_step.py | 31 +
test/test_ecc.py | 2 +-
test/test_migrate_cli.py | 797 ++++++++++---
test/test_param_ecc_compile.py | 2 +-
test/test_toy_training.py | 6 +-
test/utils.py | 2 +-
17 files changed, 908 insertions(+), 2474 deletions(-)
delete mode 100644 interactive/playground.py
delete mode 100644 interactive/static/init-globals.js
delete mode 100644 interactive/static/node-editor.js
diff --git a/benchmarks/bench_singular_values.py b/benchmarks/bench_singular_values.py
index c324b64..f95f735 100644
--- a/benchmarks/bench_singular_values.py
+++ b/benchmarks/bench_singular_values.py
@@ -113,7 +113,7 @@ def key_fn(r):
f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {min(rerrs):>10.6f} {max(rerrs):>10.6f} {errs:>6} {len(items):>5}"
)
else:
- print(f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {'—':>10} {'—':>10} {errs:>6} {len(items):>5}")
+ print(f"{key[0]:<8} {key[1]:<5} {key[2]:>3} {'-':>10} {'-':>10} {errs:>6} {len(items):>5}")
def main():
diff --git a/docs/benchmark.md b/docs/benchmark.md
index 5fdb320..f6ff709 100644
--- a/docs/benchmark.md
+++ b/docs/benchmark.md
@@ -82,7 +82,7 @@ informed choice.
### Case Study: Escaping the Saddle Point
An optimizer’s inability to navigate a saddle point is a classic example of a silent failure. A key test of an
-optimizer's robustness is its ability to navigate a saddle point—a region that is a minimum in one direction but a
+optimizer's robustness is its ability to navigate a saddle point - a region that is a minimum in one direction but a
maximum in another. The gradient approaches zero at the center, trapping first-order methods that rely solely on the
gradient.
@@ -95,7 +95,7 @@ optimizer may be unreliable in these settings.
## Conclusion
The HeavyBall Benchmark represents a necessary shift in how we evaluate optimizers, moving from a culture of
-score-chasing to one of deep, diagnostic understanding. These hidden failures aren’t rare edge cases—they’re a routine
+score-chasing to one of deep, diagnostic understanding. These hidden failures aren’t rare edge cases - they’re a routine
source of wasted compute and disappointing models. By making them explicit, the benchmark equips researchers and
practitioners with a detailed map of an optimizer's capabilities. By clearly identifying hidden failure modes,
practitioners can confidently choose, tune, or reconsider their optimization strategies, ultimately leading to more
diff --git a/examples/autoencoder.py b/examples/autoencoder.py
index 20350cb..615ddba 100644
--- a/examples/autoencoder.py
+++ b/examples/autoencoder.py
@@ -96,7 +96,6 @@ def main(epochs: int, batch: int, log_interval: int = 16):
lr=1e-4,
mars=True,
lower_bound_beta=0.9,
- inverse_free=True,
precond_update_power_iterations=6,
store_triu_as_line=False,
)
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index d760970..a7a69ee 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -598,7 +598,30 @@ def __init__(
)
-class SOAP(C.BaseOpt):
+class SOAPBase(C.BaseOpt):
+ use_precond_schedule: bool = False
+
+ def _build_soap_defaults(self, locals_dict, fns):
+ use_precond_schedule = C.default(locals_dict["use_precond_schedule"], self.use_precond_schedule)
+ params, defaults = C._build_defaults(locals_dict)
+ if use_precond_schedule:
+ del defaults["precondition_frequency"]
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+ else:
+ del defaults["precond_scheduler"]
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+ super().__init__(
+ params,
+ defaults,
+ locals_dict["multi_tensor"],
+ locals_dict["gradient_clipping"],
+ locals_dict["update_clipping"],
+ locals_dict.get("palm", False),
+ fns=fns,
+ )
+
+
+class SOAP(SOAPBase):
"""
SOAP
@@ -610,8 +633,6 @@ class SOAP(C.BaseOpt):
https://github.com/nikhilvyas/SOAP
"""
- use_precond_schedule: bool = False
-
def __init__(
self,
params,
@@ -644,30 +665,10 @@ def __init__(
param_ecc: str | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap,))
- params, defaults = C._build_defaults(locals())
-
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- multi_tensor,
- gradient_clipping,
- update_clipping,
- palm, #
- fns=(C.scale_by_soap,),
- )
-
-
-class SOAPNAdam(C.BaseOpt):
- use_precond_schedule: bool = False
+class SOAPNAdam(SOAPBase):
def __init__(
self,
params,
@@ -702,30 +703,10 @@ def __init__(
param_ecc: str | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap_nadam,))
- params, defaults = C._build_defaults(locals())
-
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- multi_tensor,
- gradient_clipping,
- update_clipping,
- palm,
- fns=(C.scale_by_soap_nadam,),
- )
-
-
-class SOAPAdEMAMix(C.BaseOpt):
- use_precond_schedule: bool = False
+class SOAPAdEMAMix(SOAPBase):
def __init__(
self,
params,
@@ -761,25 +742,7 @@ def __init__(
param_ecc: str | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
-
- params, defaults = C._build_defaults(locals())
-
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- multi_tensor,
- gradient_clipping,
- update_clipping,
- palm,
- fns=(C.scale_by_soap_ademamix,),
- )
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap_ademamix,))
class SignLaProp(C.BaseOpt):
@@ -818,7 +781,7 @@ def __init__(
)
-class SOLP(C.BaseOpt):
+class SOLP(SOAPBase):
"""
SOLP
@@ -830,8 +793,6 @@ class SOLP(C.BaseOpt):
https://github.com/nikhilvyas/SOAP
"""
- use_precond_schedule: bool = False
-
def __init__(
self,
params,
@@ -863,25 +824,7 @@ def __init__(
param_ecc: str | None = None,
**kwargs,
):
- use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
-
- params, defaults = C._build_defaults(locals())
-
- if use_precond_schedule:
- del defaults["precondition_frequency"]
- self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
- else:
- del defaults["precond_scheduler"]
- self.precond_schedule = 1 / defaults.pop("precondition_frequency")
- super().__init__(
- params,
- defaults,
- multi_tensor,
- gradient_clipping,
- update_clipping,
- palm, #
- fns=(C.scale_by_soap_laprop,),
- )
+ self._build_soap_defaults(locals(), fns=(C.scale_by_soap_laprop,))
class OrthoLaProp(C.BaseOpt):
@@ -956,7 +899,37 @@ def __init__(
)
-class PSGDKron(C.BaseOpt):
+class PSGDBase(C.BaseOpt):
+ delayed: bool = False
+ cached: bool = False
+ exp_avg_input: bool = True
+
+ def _build_psgd_defaults(self, locals_dict, fns, *, default_update_clipping=utils.trust_region_clip_, extra_defaults=None):
+ exp_avg_input = C.default(locals_dict.get("exp_avg_input", C.use_default), self.exp_avg_input)
+ update_clipping = C.default(locals_dict["update_clipping"], default_update_clipping)
+
+ locals_dict = {**locals_dict, "exp_avg_input": exp_avg_input, "update_clipping": update_clipping}
+ params, defaults = C._build_defaults(locals_dict)
+
+ self.precond_schedule = C.default(
+ defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
+ )
+
+ if extra_defaults:
+ defaults.update(extra_defaults)
+
+ super().__init__(
+ params,
+ defaults,
+ locals_dict["multi_tensor"],
+ locals_dict["gradient_clipping"],
+ update_clipping,
+ False,
+ fns=(*(C.exp_avg,) * exp_avg_input, *fns),
+ )
+
+
+class PSGDKron(PSGDBase):
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
Modified under Creative Commons Attribution 4.0 International
@@ -1016,30 +989,13 @@ def __init__(
):
delayed = C.default(delayed, self.delayed)
cached = C.default(cached, self.cached)
- exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
- update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
- params, defaults = C._build_defaults(locals())
-
- self.precond_schedule = C.default(
- defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
- )
-
- super().__init__(
- params,
- defaults,
- multi_tensor,
- gradient_clipping,
- update_clipping,
- False, #
- fns=(
- *(C.exp_avg,) * exp_avg_input,
- functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),
- ),
+ self._build_psgd_defaults(
+ locals(),
+ fns=(functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached),),
)
-class PSGDPRO(C.BaseOpt):
+class PSGDPRO(PSGDBase):
"""
PSGD with Q0.5EQ1.5 (PRO/Procrustes) preconditioner update.
Solve-free alternative to standard PSGD-Kron (EQ method).
@@ -1089,31 +1045,15 @@ def __init__(
**kwargs,
):
cached = C.default(cached, self.cached)
- exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
- update_clipping = C.default(update_clipping, None)
-
- params, defaults = C._build_defaults(locals())
- defaults["store_triu_as_line"] = False
-
- self.precond_schedule = C.default(
- defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
- )
-
- super().__init__(
- params,
- defaults,
- multi_tensor,
- gradient_clipping,
- update_clipping,
- False,
- fns=(
- *(C.exp_avg,) * exp_avg_input,
- functools.partial(C.scale_by_psgd_pro, cached=cached),
- ),
+ self._build_psgd_defaults(
+ locals(),
+ fns=(functools.partial(C.scale_by_psgd_pro, cached=cached),),
+ default_update_clipping=None,
+ extra_defaults={"store_triu_as_line": False},
)
-class PSGDLRA(C.BaseOpt):
+class PSGDLRA(PSGDBase):
"""
Originally from Evan Walters and Omead Pooladzandi, 2024
Modified under Creative Commons Attribution 4.0 International
@@ -1164,31 +1104,18 @@ def __init__(
**kwargs,
):
delayed = C.default(delayed, self.delayed)
- exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
- update_clipping = C.default(update_clipping, utils.trust_region_clip_)
-
- params, defaults = C._build_defaults(locals())
-
- self.precond_schedule = C.default(
- defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
- )
if rank is None:
utils.warn_once(
f"{rank=}. It will be set to log2(param_count). This requires `params` to be of type list. Currently, {type(params)=}"
)
params = list(params)
- defaults["rank"] = max(1, round(math.log2(sum(p.numel() for p in params))))
- utils.warn_once(f"rank was set to {defaults['rank']}")
+ rank = max(1, round(math.log2(sum(p.numel() for p in params))))
+ utils.warn_once(f"rank was set to {rank}")
- super().__init__(
- params,
- defaults,
- multi_tensor,
- gradient_clipping,
- update_clipping,
- False, #
- fns=(*(C.exp_avg,) * exp_avg_input, C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra),
+ self._build_psgd_defaults(
+ locals(),
+ fns=(C.scale_by_delayed_psgd_lra if delayed else C.scale_by_psgd_lra,),
)
@@ -1288,4 +1215,5 @@ def zero_grad(self, set_to_none: bool = True):
capture_param_shapes = utils.capture_param_shapes
-__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer)]
+_BASE_CLASSES = {SOAPBase, PSGDBase}
+__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer) and v not in _BASE_CLASSES]
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index d9cd125..b37e43d 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -1886,7 +1886,12 @@ def _run_chain(self, state, group, g, p, caution):
def _needs_init(self, state):
ids = self._transform_ids
- return ids and any(not ids.issubset(st.get("is_initialized", set())) for st in state)
+ if not ids:
+ return False
+ all_initialized = set()
+ for st in state:
+ all_initialized.update(st.get("is_initialized", ()))
+ return not ids.issubset(all_initialized)
def _needs_eager(self, group, state):
if self._needs_init(state):
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 50642ca..44b4b28 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -3102,8 +3102,8 @@ def eigvecs_product_rank1(
using the Householder reflector with first column v. Never materializes V.
Args:
- G: shape (..., d) — gradient row(s) you want to rotate into eigenbasis.
- v: shape (d,) — current unit direction (top eigenvector of P).
+ G: shape (..., d) - gradient row(s) you want to rotate into eigenbasis.
+ v: shape (d,) - current unit direction (top eigenvector of P).
w: optional Householder vector w; pass to reuse across calls.
Returns:
diff --git a/interactive/playground.py b/interactive/playground.py
deleted file mode 100644
index bc80f62..0000000
--- a/interactive/playground.py
+++ /dev/null
@@ -1,1596 +0,0 @@
-import functools
-from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple
-
-import gradio as gr
-import numpy as np
-import plotly.graph_objects as go
-import torch
-import torch.nn as nn
-from plotly.subplots import make_subplots
-from sklearn.decomposition import PCA
-
-import heavyball
-import heavyball.chainable as C
-
-# TensorFlow Playground inspired colors
-COLORS = {
- "bg": "#ffffff",
- "surface": "#f7f7f7",
- "primary": "#ff6f00", # Orange
- "secondary": "#0d47a1", # Blue
- "positive": "#4caf50", # Green
- "negative": "#f44336", # Red
- "text": "#212121",
- "text_light": "#757575",
- "border": "#e0e0e0", # Component categories
- "gradient": "#9c27b0", # Purple - Gradient input
- "momentum": "#2196f3", # Blue - Momentum transforms
- "scaling": "#ff5722", # Deep Orange - Scaling transforms
- "regularization": "#009688", # Teal - Regularization
- "normalization": "#795548", # Brown - Normalization
- "update": "#4caf50", # Green - Update rules
-}
-
-
-# Problem base class
-class Problem(ABC):
- """Base class for optimization problems"""
-
- @property
- @abstractmethod
- def dim(self) -> int:
- """Dimension of the parameter space"""
- pass
-
- @abstractmethod
- def init(self) -> np.ndarray:
- """Initial parameters"""
- pass
-
- @abstractmethod
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- """Compute loss given parameters"""
- pass
-
- def bounds(self) -> Optional[List[Tuple[float, float]]]:
- """Parameter bounds for visualization (None if unbounded)"""
- return None
-
- def optimum(self) -> Optional[np.ndarray]:
- """Known optimal point if available"""
- return None
-
-
-class Function2D(Problem):
- """Wrapper for 2D test functions"""
-
- def __init__(self, func, bounds, init, optimal=None):
- self.func = func
- self._bounds = bounds
- self._init = init
- self._optimal = optimal
-
- @property
- def dim(self) -> int:
- return 2
-
- def init(self) -> np.ndarray:
- return np.array(self._init)
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- if x.dim() == 1:
- return self.func(x)
- else:
- # Handle batch of parameters
- return self.func(x[0])
-
- def bounds(self):
- return self._bounds
-
- def optimum(self):
- return np.array(self._optimal) if self._optimal else None
-
-
-class MLPProblem(Problem):
- """Train a small MLP to predict |x|"""
-
- def _data(self, n_samples: int = 128, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
- def __init__(self, hidden_size=4, n_samples=128):
- self.model = nn.Sequential(
- nn.Linear(1, hidden_size),
- nn.ReLU(),
- nn.Linear(hidden_size, hidden_size),
- nn.ReLU(),
- nn.Linear(hidden_size, 1),
- )
- self.x_data, self.y_data = self._data(n_samples)
- self._dim = sum(p.numel() for p in self.model.parameters())
-
- @property
- def dim(self) -> int:
- return self._dim
-
- def init(self) -> np.ndarray:
- # Initialize model and return flattened parameters
- torch.manual_seed(42)
- for p in self.model.parameters():
- if len(p.shape) >= 2:
- nn.init.xavier_uniform_(p)
- else:
- nn.init.zeros_(p)
- return self._flatten_params()
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- # Instead of modifying model parameters, we'll use functional approach
- # Split x into weight and bias tensors
- idx = 0
- params = []
- for p in self.model.parameters():
- numel = p.numel()
- param = x[idx : idx + numel].view(p.shape)
- params.append(param)
- idx += numel
-
- # Manually compute forward pass using the parameters from x
- # For a 2-layer MLP: Linear -> ReLU -> Linear
- w1, b1, w2, b2, w3, b3 = params
-
- h = torch.nn.functional.linear(self.x_data, w1, b1)
- h = torch.nn.functional.relu(h)
- pred = torch.nn.functional.linear(h, w2, b2)
- h = torch.nn.functional.relu(h)
- pred = torch.nn.functional.linear(h, w3, b3)
-
- return nn.functional.mse_loss(pred, self.y_data)
-
- def _flatten_params(self) -> np.ndarray:
- """Flatten model parameters to vector"""
- return torch.cat([p.data.view(-1) for p in self.model.parameters()]).numpy()
-
- def _unflatten_params(self, x: torch.Tensor):
- """Set model parameters from flat vector (used for evaluation, not training)"""
- idx = 0
- for p in self.model.parameters():
- numel = p.numel()
- p.data.copy_(x[idx : idx + numel].view(p.shape))
- idx += numel
-
-
-class AbsMLP(MLPProblem):
- def _data(self, n_samples: int = 128, seed: int = 42):
- x = torch.rand((n_samples, 1)) * 4 - 2
- return x, x.abs()
-
-
-class SineMLP(MLPProblem):
- def _data(self, n_samples: int = 128, seed: int = 42):
- x = torch.rand((n_samples, 1)) * 4 - 2
- return x, x.sin()
-
-
-class ExpMLP(MLPProblem):
- def _data(self, n_samples: int = 128, seed: int = 42):
- x = torch.rand((n_samples, 1)) * 4 - 2
- return x, x.exp()
-
-
-class QuadraticBowl(Problem):
- """N-dimensional quadratic bowl"""
-
- def __init__(self, dim=10):
- self._dim = dim
- self.center = np.random.randn(dim) * 2
-
- @property
- def dim(self) -> int:
- return self._dim
-
- def init(self) -> np.ndarray:
- return np.random.randn(self._dim) * 3
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- center = torch.tensor(self.center, dtype=x.dtype, device=x.device)
- return torch.sum((x - center) ** 2)
-
- def optimum(self):
- return self.center
-
-
-class StyblinskiTang(Problem):
- """Styblinski-Tang function (separable, multimodal)"""
-
- def __init__(self, dim=4):
- self._dim = dim
-
- @property
- def dim(self) -> int:
- return self._dim
-
- def init(self) -> np.ndarray:
- return np.random.uniform(-5, 5, self._dim)
-
- def loss(self, x: torch.Tensor) -> torch.Tensor:
- return 0.5 * torch.sum(x**4 - 16 * x**2 + 5 * x)
-
- def bounds(self):
- return [(-5, 5)] * self._dim
-
- def optimum(self):
- # Global minimum at x_i = -2.903534 for all i
- return np.full(self._dim, -2.903534)
-
-
-# Test functions
-PROBLEMS = { # 2D Problems
- "Simple Bowl (2D)": Function2D(
- func=lambda x: (x[0] - 1) ** 2 + (x[1] - 2) ** 2,
- bounds=[(-3, 5), (-2, 6)],
- init=[-1.0, 0.0],
- optimal=[1.0, 2.0],
- ),
- "Ravine (2D)": Function2D(
- func=lambda x: (x[0] - 1) ** 2 + (x[1] - 2) ** 100,
- bounds=[(-3, 5), (-2, 6)],
- init=[-1.0, 1.0],
- optimal=[1.0, 2.0],
- ),
- "Rosenbrock (2D)": Function2D(
- func=lambda x: (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2,
- bounds=[(-2, 2), (-1, 3)],
- init=[-1.0, 1.0],
- optimal=[1.0, 1.0],
- ),
- "Himmelblau (2D)": Function2D(
- func=lambda x: (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2,
- bounds=[(-5, 5), (-5, 5)],
- init=[0.0, 0.0],
- optimal=[3.0, 2.0],
- ),
- "Beale (2D)": Function2D(
- func=lambda x: (
- (1.5 - x[0] + x[0] * x[1]) ** 2
- + (2.25 - x[0] + x[0] * x[1] ** 2) ** 2
- + (2.625 - x[0] + x[0] * x[1] ** 3) ** 2
- ),
- bounds=[(-4.5, 4.5), (-4.5, 4.5)],
- init=[1.0, 1.0],
- optimal=[3.0, 0.5],
- ), # High-dimensional Problems
- "MLP abs(x) (13D)": AbsMLP(hidden_size=4, n_samples=128),
- "MLP sin(x) (13D)": SineMLP(hidden_size=4, n_samples=128),
- "MLP exp(x) (13D)": ExpMLP(hidden_size=4, n_samples=128),
- "Quadratic Bowl (10D)": QuadraticBowl(dim=10),
- "Styblinski-Tang (4D)": StyblinskiTang(dim=4),
- "Styblinski-Tang (8D)": StyblinskiTang(dim=8),
-}
-
-# Chainable optimizer components
-OPTIMIZER_BLOCKS = {
- "gradient": {
- "name": "🎯 Gradient Input",
- "color": COLORS["gradient"],
- "components": [
- {
- "id": "gradient_input",
- "name": "Raw Gradient",
- "icon": "∇",
- "description": "Start of the pipeline - raw gradients from backprop",
- "func": None, # This is just a placeholder
- "params": {},
- }
- ],
- },
- "momentum": {
- "name": "⚡ Momentum",
- "color": COLORS["momentum"],
- "components": [
- {
- "id": "heavyball",
- "name": "Heavy Ball",
- "icon": "🏐",
- "description": "Classic momentum: v = βv - g",
- "func": C.heavyball_momentum,
- "params": {"beta": 0.9},
- },
- {
- "id": "nesterov",
- "name": "Nesterov",
- "icon": "🚀",
- "description": "Look-ahead momentum",
- "func": C.nesterov_momentum,
- "params": {"beta": 0.9},
- },
- {
- "id": "nesterov_ema",
- "name": "Nesterov EMA",
- "icon": "📈",
- "description": "Exponential moving average variant",
- "func": C.nesterov_ema,
- "params": {"beta": 0.9},
- },
- {
- "id": "basic_ema",
- "name": "Basic EMA",
- "icon": "📉",
- "description": "Simple exponential moving average",
- "func": C.exp_avg,
- "params": {"betas": (0.9, 0.999)},
- },
- ],
- },
- "scaling": {
- "name": "📊 Adaptive Scaling",
- "color": COLORS["scaling"],
- "components": [
- {
- "id": "adam_scale",
- "name": "Adam Scaling",
- "icon": "👑",
- "description": "Adaptive per-parameter learning rates",
- "func": C.scale_by_adam,
- "params": {"betas": (0.9, 0.999), "eps": 1e-8},
- },
- {
- "id": "rmsprop_scale",
- "name": "RMSprop Scaling",
- "icon": "📉",
- "description": "Root mean square normalization",
- "func": C.scale_by_exp_avg_sq,
- "params": {"beta2": 0.999, "eps": 1e-8},
- },
- {
- "id": "adagrad_scale",
- "name": "AdaGrad Scaling",
- "icon": "📐",
- "description": "Accumulate all past gradients",
- "func": C.scale_by_exp_avg_sq,
- "params": {"beta2": 1.0, "eps": 1e-8},
- },
- ],
- },
- "regularization": {
- "name": "⚖️ Regularization",
- "color": COLORS["regularization"],
- "components": [
- {
- "id": "weight_decay",
- "name": "Weight Decay",
- "icon": "🪶",
- "description": "L2 regularization (AdamW style)",
- "func": C.weight_decay_to_ema,
- "params": {"weight_decay_to_ema": 0.01, "ema_beta": 0.999},
- },
- {
- "id": "weight_decay_init",
- "name": "Decay to Init",
- "icon": "🎯",
- "description": "Pull weights toward initialization",
- "func": C.weight_decay_to_init,
- "params": {"weight_decay_to_init": 0.01},
- },
- {
- "id": "l1_weight_decay",
- "name": "L1 Weight Decay",
- "icon": "⚡",
- "description": "L1 regularization to EMA",
- "func": C.l1_weight_decay_to_ema,
- "params": {"weight_decay_to_ema": 0.01, "ema_beta": 0.999},
- },
- ],
- },
- "normalization": {
- "name": "🔧 Gradient Processing",
- "color": COLORS["normalization"],
- "components": [
- {
- "id": "grad_clip",
- "name": "Gradient Clipping",
- "icon": "✂️",
- "description": "Clip gradient norm",
- "func": functools.partial(C.global_clip, clip_fn=heavyball.utils.l2_clip_),
- "params": {"max_norm": 1.0},
- },
- {
- "id": "sign",
- "name": "Sign SGD",
- "icon": "±",
- "description": "Use only gradient signs",
- "func": C.sign,
- "params": {"graft": True},
- },
- {
- "id": "orthogonalize",
- "name": "Orthogonalize",
- "icon": "⊥",
- "description": "Orthogonalize gradient to parameter",
- "func": C.orthogonalize_grad_to_param,
- "params": {"eps": 1e-8},
- },
- {
- "id": "orthogonalize_update",
- "name": "Orthogonalize Update",
- "icon": "⊗",
- "description": "Orthogonalize the update itself",
- "func": C.orthogonalize_update,
- "params": {},
- },
- ],
- },
- "advanced_scaling": {
- "name": "🚀 Advanced Optimizers",
- "color": COLORS["scaling"],
- "components": [
- {
- "id": "laprop",
- "name": "Laprop",
- "icon": "🌊",
- "description": "Layerwise adaptive propagation",
- "func": C.scale_by_laprop,
- "params": {"betas": (0.9, 0.999), "eps": 1e-8},
- },
- {
- "id": "adopt",
- "name": "ADOPT",
- "icon": "🎯",
- "description": "Adaptive gradient methods",
- "func": C.scale_by_adopt,
- "params": {"betas": (0.9, 0.9999), "eps": 1e-12},
- },
- ],
- },
- "preconditioning": {
- "name": "🔮 Preconditioning",
- "color": COLORS["gradient"],
- "components": [
- {
- "id": "soap",
- "name": "SOAP",
- "icon": "🧼",
- "description": "Shampoo-based preconditioning",
- "func": C.scale_by_soap,
- "params": {
- "shampoo_beta": 0.99,
- "max_precond_dim": 10000,
- "precondition_1d": False,
- "is_preconditioning": True,
- "betas": (0.9, 0.999),
- "eps": 1e-8,
- },
- },
- {
- "id": "psgd",
- "name": "PSGD",
- "icon": "🎲",
- "description": "Preconditioned SGD",
- "func": functools.partial(C.scale_by_psgd, cached=False),
- "params": {
- "precond_lr": 0.1,
- "max_size_triangular": 1024,
- "precondition_frequency": 10,
- "adaptive": False,
- "store_triu_as_line": True,
- "q_dtype": "float32",
- "inverse_free": False,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "min_ndim_triangular": 2,
- "memory_save_mode": None,
- "dampening": 1.0,
- "is_preconditioning": True,
- "betas": (0.9, 0.999),
- "eps": 1e-8,
- "ortho_method": "qr",
- "lower_bound_beta": 0.999,
- "precond_update_power_iterations": 1,
- },
- },
- {
- "id": "psgd_lra",
- "name": "PSGD LRA",
- "icon": "📐",
- "description": "Low-rank approximation PSGD",
- "func": C.scale_by_psgd_lra,
- "params": {
- "precond_lr": 0.1,
- "rank": 4,
- "param_count": 10000,
- "precondition_frequency": 10,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "q_dtype": "float32",
- "is_preconditioning": True,
- "eps": 1e-8,
- "betas": (0.9, 0.999),
- },
- },
- {
- "id": "delayed_psgd",
- "name": "Delayed PSGD",
- "icon": "⏱️",
- "description": "PSGD with delayed preconditioner updates",
- "func": functools.partial(C.scale_by_delayed_psgd, cached=False),
- "params": {
- "precond_lr": 0.1,
- "max_size_triangular": 1024,
- "precondition_frequency": 10,
- "adaptive": False,
- "store_triu_as_line": True,
- "q_dtype": "float32",
- "inverse_free": False,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "min_ndim_triangular": 2,
- "memory_save_mode": None,
- "dampening": 1.0,
- "is_preconditioning": True,
- "betas": (0.9, 0.999),
- "eps": 1e-8,
- "ortho_method": "qr",
- "lower_bound_beta": 0.999,
- "precond_update_power_iterations": 1,
- },
- },
- {
- "id": "delayed_psgd_lra",
- "name": "Delayed PSGD LRA",
- "icon": "⏰",
- "description": "Delayed low-rank PSGD",
- "func": C.scale_by_delayed_psgd_lra,
- "params": {
- "precond_lr": 0.1,
- "rank": 4,
- "param_count": 10000,
- "precondition_frequency": 10,
- "precond_init_scale": 1.0,
- "precond_init_scale_scale": 0.0,
- "precond_init_scale_power": 1.0,
- "q_dtype": "float32",
- "is_preconditioning": True,
- "eps": 1e-8,
- "betas": (0.9, 0.999),
- },
- },
- ],
- },
- "adaptive_lr": {
- "name": "🎛️ Adaptive Learning Rate",
- "color": COLORS["regularization"],
- "components": [
- {
- "id": "d_adapt",
- "name": "D-Adaptation",
- "icon": "📈",
- "description": "Automatic learning rate adaptation",
- "func": C.scale_by_d_adaptation,
- "params": {"initial_d": 1.0},
- },
- {
- "id": "lr_adapt",
- "name": "LR Adaptation",
- "icon": "🔄",
- "description": "Learning rate adaptation",
- "func": C.scale_by_lr_adaptation,
- "params": {"initial_d": 1.0, "lr_lr": 0.1},
- },
- {
- "id": "pointwise_lr_adapt",
- "name": "Pointwise LR",
- "icon": "🎚️",
- "description": "Per-parameter learning rate",
- "func": C.scale_by_pointwise_lr_adaptation,
- "params": {"initial_d": 1.0, "lr_lr": 0.1},
- },
- ],
- },
- "special": {
- "name": "✨ Special Methods",
- "color": COLORS["update"],
- "components": [
- {
- "id": "schedule_free",
- "name": "Schedule-Free",
- "icon": "🗓️",
- "description": "No learning rate schedule needed",
- "func": C.update_by_schedule_free,
- "params": {"r": 0.0, "weight_lr_power": 2.0},
- },
- {
- "id": "msam",
- "name": "MSAM",
- "icon": "🏔️",
- "description": "Momentum SAM optimizer",
- "func": C.update_by_msam,
- "params": {"sam_step_size": 0.05},
- },
- {
- "id": "mup",
- "name": "μP Approx",
- "icon": "📏",
- "description": "Maximal update parametrization",
- "func": C.mup_approx,
- "params": {},
- },
- {
- "id": "palm_beta2",
- "name": "PALM β₂",
- "icon": "🌴",
- "description": "Dynamic β₂ scheduling for PALM",
- "func": C.palm_beta2,
- "params": {"beta2_scale": 0.8},
- },
- {
- "id": "identity",
- "name": "Identity",
- "icon": "🔄",
- "description": "Pass-through (no operation)",
- "func": C.identity,
- "params": {},
- },
- ],
- },
-}
-
-# Pre-built optimizer recipes
-RECIPES = {
- "SGD": ["gradient_input"],
- "SGD + Momentum": ["gradient_input", "heavyball"],
- "Adam": ["gradient_input", "adam_scale"],
- "AdamW": ["gradient_input", "adam_scale", "weight_decay"],
- "RMSprop": ["gradient_input", "rmsprop_scale"],
- "Nesterov SGD": ["gradient_input", "nesterov"],
- "SOAP": ["gradient_input", "soap"],
- "Laprop": ["gradient_input", "laprop"],
- "ADOPT": ["gradient_input", "adopt"],
- "Sign SGD": ["gradient_input", "sign"],
- "AdamW + Clipping": ["gradient_input", "grad_clip", "adam_scale", "weight_decay"],
- "D-Adapted Adam": ["gradient_input", "adam_scale", "d_adapt"],
- "PSGD": ["gradient_input", "psgd"],
- "Schedule-Free AdamW": ["gradient_input", "adam_scale", "weight_decay", "schedule_free"],
- "EMA SGD": ["gradient_input", "basic_ema"],
- "Orthogonal Adam": ["gradient_input", "orthogonalize_update", "adam_scale"],
- "PALM": ["gradient_input", "palm_beta2", "adam_scale"],
- "Delayed PSGD": ["gradient_input", "delayed_psgd"],
- "μP Adam": ["gradient_input", "mup", "adam_scale"],
-}
-
-
-def get_component_info(comp_id):
- """Get component info by ID"""
- for category in OPTIMIZER_BLOCKS.values():
- for comp in category["components"]:
- if comp["id"] == comp_id:
- return comp, category["color"]
- return None, "#757575"
-
-
-def create_pipeline_display(pipeline):
- """Create HTML display for the current pipeline"""
- if not pipeline or pipeline == ["gradient_input"]:
-        return """
-        <div style="text-align:center; color:#757575; padding:2em; border:2px dashed #e0e0e0; border-radius:8px;">
-            Drop components here to build your optimizer pipeline
-        </div>
-        """
-
- blocks_html = ""
- for i, comp_id in enumerate(pipeline):
- comp_info, color = get_component_info(comp_id)
- if comp_info:
- show_arrow = i < len(pipeline) - 1
-            blocks_html += f"""
-            <div style="display:inline-flex; align-items:center;">
-                <div style="text-align:center;">
-                    <div style="background:{color}; color:white; border-radius:8px;
-                                padding:8px 12px; font-size:1.5em;">
-                        {comp_info["icon"]}
-                    </div>
-                    <div style="font-size:0.8em; color:#212121; margin-top:4px;">
-                        {comp_info["name"]}
-                    </div>
-                </div>
-                {'<span style="margin:0 8px; color:#ff6f00;">→</span>' if show_arrow else ""}
-            </div>
-            """
-
-    return f"""
-    <div style="display:flex; flex-wrap:wrap; align-items:center; padding:1em;">
-        {blocks_html}
-    </div>
-    """
-
-
-def build_optimizer_from_pipeline(pipeline: List[str], params, kwargs):
- """Build optimizer from pipeline of component IDs"""
- fns = []
- opt_params = {
- "lr": 0.001,
- "step": 1, # Required for many functions
- "caution": False,
- "weight_decay": 0.0,
- **kwargs,
- }
-
- for comp_id in pipeline:
- if comp_id == "gradient_input":
- continue # Skip the input block
-
- # Find component
- comp_info, _ = get_component_info(comp_id)
- if comp_info and comp_info["func"] is not None:
- fns.append(comp_info["func"])
- # Update parameters, handling special cases
- params_to_add = comp_info["params"].copy()
-
- # Handle special parameter mappings
- if "beta" in params_to_add and "betas" not in params_to_add:
- # Convert single beta to betas tuple for functions expecting it
- beta = params_to_add.pop("beta")
- if "betas" in opt_params:
- opt_params["betas"] = (beta, opt_params["betas"][1])
- else:
- opt_params["betas"] = (beta, 0.999)
-
- # First add component defaults, then override with user-provided kwargs
- for key, value in params_to_add.items():
- if key not in kwargs: # Only use default if not provided by user
- opt_params[key] = value
-
- if not fns:
- # Default to simple gradient descent
- return C.BaseOpt(params, opt_params, fns=[C.identity])
-
- return C.BaseOpt(params, opt_params, fns=fns)
-
-
-def run_optimization(problem_name: str, pipeline: List[str], steps: int, **kwargs) -> Tuple[Any, Dict]:
- """Run optimization with the custom pipeline"""
- problem = PROBLEMS[problem_name]
- kwargs["lr"] = 10 ** kwargs.get("lr", -2)
-
- # Initialize parameters
- init = problem.init()
- x = torch.nn.Parameter(torch.tensor(init, dtype=torch.float32))
- params = [x]
-
- # Build optimizer from pipeline
- optimizer = build_optimizer_from_pipeline(pipeline, params, kwargs)
-
- # Run optimization
- trajectory = []
- losses = []
- gradients = []
-
- for i in range(steps):
- trajectory.append(x.detach().numpy().copy())
-
- def closure():
- optimizer.zero_grad()
- loss = problem.loss(x)
- loss.backward()
- if x.grad is not None:
- gradients.append(x.grad.detach().numpy().copy())
- return loss
-
- loss = optimizer.step(closure)
- optimizer.zero_grad()
- losses.append(loss.item())
-
- trajectory = np.array(trajectory)
-
- # Create visualization
- fig = create_visualization(problem_name, trajectory, losses, gradients, pipeline)
-
- return fig, {
- "trajectory": trajectory.tolist(),
- "losses": losses,
- "final_loss": losses[-1],
- "steps_to_converge": find_convergence(losses),
- }
-
-
-def find_convergence(losses, threshold=1e-6):
- """Find when optimization converged"""
- if len(losses) < 10:
- return len(losses)
-
- for i in range(10, len(losses)):
- if abs(losses[i] - losses[i - 5]) < threshold:
- return i
- return len(losses)
-
-
-def create_visualization(problem_name, trajectory, losses, gradients, pipeline):
- """Create integrated visualization"""
- problem = PROBLEMS[problem_name]
-
- if problem.dim == 2:
- return create_2d_visualization(problem_name, trajectory, losses, gradients, pipeline)
- return create_highdim_visualization(problem_name, trajectory, losses, gradients, pipeline)
-
-
-def create_2d_visualization(problem_name, trajectory, losses, gradients, pipeline):
- """Create visualization for 2D problems"""
- problem = PROBLEMS[problem_name]
- bounds = problem.bounds()
-
- # Create subplots
- fig = make_subplots(
- rows=2,
- cols=2,
- subplot_titles=("Optimization Landscape", "Pipeline Architecture", "Loss Curve", "Learning Dynamics"),
- column_widths=[0.6, 0.4],
- row_heights=[0.6, 0.4],
- specs=[[{"type": "contour"}, {"type": "scatter"}], [{"type": "scatter"}, {"type": "scatter"}]],
- )
-
- # 1. Optimization landscape
- x = np.linspace(bounds[0][0], bounds[0][1], 100)
- y = np.linspace(bounds[1][0], bounds[1][1], 100)
- X, Y = np.meshgrid(x, y)
- Z = np.zeros_like(X)
-
- for i in range(X.shape[0]):
- for j in range(X.shape[1]):
- point = torch.tensor([X[i, j], Y[i, j]], dtype=torch.float32)
- Z[i, j] = problem.loss(point).item()
-
- # Contour plot
- fig.add_trace(
- go.Contour(
- x=x,
- y=y,
- z=Z,
- colorscale=[[0, "#e3f2fd"], [0.5, "#2196f3"], [1, "#0d47a1"]],
- showscale=False,
- contours=dict(
- start=0,
- end=Z.max(),
- size=Z.max() / 15,
- ),
- ),
- row=1,
- col=1,
- )
-
- # Add optimization path
- if len(trajectory) > 0:
- colors = np.linspace(0, 1, len(trajectory))
-
- for i in range(1, len(trajectory)):
- fig.add_trace(
- go.Scatter(
- x=[trajectory[i - 1, 0], trajectory[i, 0]],
- y=[trajectory[i - 1, 1], trajectory[i, 1]],
- mode="lines",
- line=dict(color=f"rgba(255, {int(111 * (1 - colors[i]))}, 0, {0.3 + 0.7 * colors[i]})", width=3),
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=1,
- )
-
- # Start and end points
- fig.add_trace(
- go.Scatter(
- x=[trajectory[0, 0]],
- y=[trajectory[0, 1]],
- mode="markers+text",
- marker=dict(size=12, color="#4caf50", line=dict(color="white", width=2)),
- text=["Start"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- fig.add_trace(
- go.Scatter(
- x=[trajectory[-1, 0]],
- y=[trajectory[-1, 1]],
- mode="markers+text",
- marker=dict(size=12, color="#ff6f00", line=dict(color="white", width=2)),
- text=["End"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- # 2. Pipeline visualization
- block_x = []
- block_y = []
- block_text = []
- block_colors = []
-
- for i, comp_id in enumerate(pipeline):
- block_x.append(i / (len(pipeline) - 1) if len(pipeline) > 1 else 0.5)
- block_y.append(0.5)
-
- comp_info, comp_color = get_component_info(comp_id)
- block_text.append(comp_info["icon"] if comp_info else "?")
- block_colors.append(comp_color)
-
- fig.add_trace(
- go.Scatter(
- x=block_x,
- y=block_y,
- mode="markers+text",
- marker=dict(size=40, color=block_colors, line=dict(color="white", width=2)),
- text=block_text,
- textposition="middle center",
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=2,
- )
-
- # Add arrows between blocks
- for i in range(len(pipeline) - 1):
- fig.add_annotation(
- x=block_x[i + 1],
- y=0.5,
- ax=block_x[i],
- ay=0.5,
- xref="x2",
- yref="y2",
- axref="x2",
- ayref="y2",
- showarrow=True,
- arrowhead=2,
- arrowsize=1,
- arrowwidth=2,
- arrowcolor="#ff6f00",
- row=1,
- col=2,
- )
-
- # 3. Loss curve
- fig.add_trace(
- go.Scatter(
- x=list(range(len(losses))),
- y=losses,
- mode="lines",
- line=dict(color="#ff6f00", width=3),
- fill="tozeroy",
- fillcolor="rgba(255, 111, 0, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=1,
- )
-
- # 4. Learning dynamics (gradient norm)
- if gradients:
- grad_norms = [np.linalg.norm(g) for g in gradients]
- fig.add_trace(
- go.Scatter(
- x=list(range(len(grad_norms))),
- y=grad_norms,
- mode="lines",
- line=dict(color="#2196f3", width=2),
- fill="tozeroy",
- fillcolor="rgba(33, 150, 243, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=2,
- )
-
- # Update layout
- fig.update_layout(
- height=700,
- showlegend=False,
- paper_bgcolor="white",
- plot_bgcolor="white",
- margin=dict(l=40, r=40, t=60, b=40),
- )
-
- # Update axes
- fig.update_xaxes(showgrid=True, gridcolor="#f0f0f0")
- fig.update_yaxes(showgrid=True, gridcolor="#f0f0f0")
-
- # Pipeline plot
- fig.update_xaxes(showticklabels=False, showgrid=False, row=1, col=2)
- fig.update_yaxes(showticklabels=False, showgrid=False, range=[0, 1], row=1, col=2)
-
- # Loss plot
- fig.update_xaxes(title_text="Iteration", row=2, col=1)
- fig.update_yaxes(title_text="Loss", type="log", row=2, col=1)
-
- # Gradient plot
- fig.update_xaxes(title_text="Iteration", row=2, col=2)
- fig.update_yaxes(title_text="Gradient Norm", row=2, col=2)
-
- return fig
-
-
-def create_highdim_visualization(problem_name, trajectory, losses, gradients, pipeline):
- """Create visualization for high-dimensional problems using PCA"""
- problem = PROBLEMS[problem_name]
-
- # Create subplots
- fig = make_subplots(
- rows=2,
- cols=2,
- subplot_titles=("PCA Trajectory Projection", "Pipeline Architecture", "Loss Curve", "Learning Dynamics"),
- column_widths=[0.6, 0.4],
- row_heights=[0.6, 0.4],
- specs=[[{"type": "contour"}, {"type": "scatter"}], [{"type": "scatter"}, {"type": "scatter"}]],
- )
-
- # 1. PCA projection of trajectory
- if len(trajectory) > 2:
- # Fit PCA on trajectory
- trajectory_array = np.array(trajectory)
- pca = PCA(n_components=min(2, trajectory_array.shape[1]))
- trajectory_2d = pca.fit_transform(trajectory_array)
-
- # Create loss landscape in PCA space
- # Determine bounds based on trajectory
- margin = 0.2
- x_min, x_max = trajectory_2d[:, 0].min(), trajectory_2d[:, 0].max()
- x_range = x_max - x_min
- x_min -= margin * x_range
- x_max += margin * x_range
-
- if trajectory_2d.shape[1] > 1:
- y_min, y_max = trajectory_2d[:, 1].min(), trajectory_2d[:, 1].max()
- y_range = y_max - y_min
- y_min -= margin * y_range
- y_max += margin * y_range
- else:
- y_min, y_max = -1, 1
-
- # Create grid in PCA space
- n_points = 50
- x_grid = np.linspace(x_min, x_max, n_points)
- y_grid = np.linspace(y_min, y_max, n_points)
- X_grid, Y_grid = np.meshgrid(x_grid, y_grid)
-
- # Compute losses on grid
- Z_grid = np.zeros_like(X_grid)
- mean_trajectory = pca.mean_
-
- for i in range(n_points):
- for j in range(n_points):
- # Map 2D point back to high-dimensional space
- pca_coords = np.array([X_grid[i, j], Y_grid[i, j]])
- if trajectory_2d.shape[1] == 1:
- pca_coords = pca_coords[:1] # Use only first coordinate
-
- # Reconstruct high-dimensional point
- high_dim_point = mean_trajectory + pca_coords @ pca.components_[: len(pca_coords)]
-
- # Evaluate loss
- x_tensor = torch.tensor(high_dim_point, dtype=torch.float32)
- Z_grid[i, j] = problem.loss(x_tensor).item()
-
- # Add contour plot
- fig.add_trace(
- go.Contour(
- x=x_grid,
- y=y_grid,
- z=Z_grid,
- colorscale=[[0, "#e3f2fd"], [0.5, "#2196f3"], [1, "#0d47a1"]],
- showscale=False,
- contours=dict(
- start=Z_grid.min(),
- end=Z_grid.max(),
- size=(Z_grid.max() - Z_grid.min()) / 15,
- ),
- ),
- row=1,
- col=1,
- )
-
- # Add trajectory on top of contour
- colors = np.linspace(0, 1, len(trajectory_2d))
-
- for i in range(1, len(trajectory_2d)):
- fig.add_trace(
- go.Scatter(
- x=[trajectory_2d[i - 1, 0], trajectory_2d[i, 0]],
- y=[trajectory_2d[i - 1, 1] if trajectory_2d.shape[1] > 1 else [0, 0]],
- mode="lines",
- line=dict(color=f"rgba(255, {int(111 * (1 - colors[i]))}, 0, {0.3 + 0.7 * colors[i]})", width=3),
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=1,
- )
-
- # Start and end points
- fig.add_trace(
- go.Scatter(
- x=[trajectory_2d[0, 0]],
- y=[trajectory_2d[0, 1] if trajectory_2d.shape[1] > 1 else 0],
- mode="markers+text",
- marker=dict(size=12, color="#4caf50", line=dict(color="white", width=2)),
- text=["Start"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- fig.add_trace(
- go.Scatter(
- x=[trajectory_2d[-1, 0]],
- y=[trajectory_2d[-1, 1] if trajectory_2d.shape[1] > 1 else 0],
- mode="markers+text",
- marker=dict(size=12, color="#ff6f00", line=dict(color="white", width=2)),
- text=["End"],
- textposition="top center",
- showlegend=False,
- ),
- row=1,
- col=1,
- )
-
- # Add explained variance text
- if trajectory_2d.shape[1] > 1:
- explained_var = pca.explained_variance_ratio_
- fig.add_annotation(
- x=0.5,
- y=1.1,
- xref="x domain",
- yref="y domain",
- text=f"Explained variance: PC1={explained_var[0]:.1%}, PC2={explained_var[1]:.1%}",
- showarrow=False,
- row=1,
- col=1,
- )
-
- # 2. Pipeline visualization (same as 2D)
- block_x = []
- block_y = []
- block_text = []
- block_colors = []
-
- for i, comp_id in enumerate(pipeline):
- block_x.append(i / (len(pipeline) - 1) if len(pipeline) > 1 else 0.5)
- block_y.append(0.5)
-
- comp_info, comp_color = get_component_info(comp_id)
- block_text.append(comp_info["icon"] if comp_info else "?")
- block_colors.append(comp_color)
-
- fig.add_trace(
- go.Scatter(
- x=block_x,
- y=block_y,
- mode="markers+text",
- marker=dict(size=40, color=block_colors, line=dict(color="white", width=2)),
- text=block_text,
- textposition="middle center",
- showlegend=False,
- hoverinfo="skip",
- ),
- row=1,
- col=2,
- )
-
- # Add arrows between blocks
- for i in range(len(pipeline) - 1):
- fig.add_annotation(
- x=block_x[i + 1],
- y=0.5,
- ax=block_x[i],
- ay=0.5,
- xref="x2",
- yref="y2",
- axref="x2",
- ayref="y2",
- showarrow=True,
- arrowhead=2,
- arrowsize=1,
- arrowwidth=2,
- arrowcolor="#ff6f00",
- row=1,
- col=2,
- )
-
- # 3. Loss curve
- fig.add_trace(
- go.Scatter(
- x=list(range(len(losses))),
- y=losses,
- mode="lines",
- line=dict(color="#ff6f00", width=3),
- fill="tozeroy",
- fillcolor="rgba(255, 111, 0, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=1,
- )
-
- # 4. Learning dynamics (gradient norm)
- if gradients:
- grad_norms = [np.linalg.norm(g) for g in gradients]
- fig.add_trace(
- go.Scatter(
- x=list(range(len(grad_norms))),
- y=grad_norms,
- mode="lines",
- line=dict(color="#2196f3", width=2),
- fill="tozeroy",
- fillcolor="rgba(33, 150, 243, 0.1)",
- showlegend=False,
- ),
- row=2,
- col=2,
- )
-
- # Update layout
- fig.update_layout(
- height=700,
- showlegend=False,
- paper_bgcolor="white",
- plot_bgcolor="white",
- margin=dict(l=40, r=40, t=60, b=40),
- )
-
- # Update axes
- fig.update_xaxes(showgrid=True, gridcolor="#f0f0f0")
- fig.update_yaxes(showgrid=True, gridcolor="#f0f0f0")
-
- # PCA plot
- fig.update_xaxes(title_text="First Principal Component", row=1, col=1)
- fig.update_yaxes(title_text="Second Principal Component", row=1, col=1)
-
- # Pipeline plot
- fig.update_xaxes(showticklabels=False, showgrid=False, row=1, col=2)
- fig.update_yaxes(showticklabels=False, showgrid=False, range=[0, 1], row=1, col=2)
-
- # Loss plot
- fig.update_xaxes(title_text="Iteration", row=2, col=1)
- fig.update_yaxes(title_text="Loss", type="log", row=2, col=1)
-
- # Gradient plot
- fig.update_xaxes(title_text="Iteration", row=2, col=2)
- fig.update_yaxes(title_text="Gradient Norm", row=2, col=2)
-
- return fig
-
-
-def create_app():
- with gr.Blocks(theme=gr.themes.Base()) as app:
- # Custom CSS
- gr.HTML("""
-
- """)
-
- # Header
- gr.Markdown("""
- # 🧩 HeavyBall Chainable Optimizer Playground
-
- ### Build custom optimizers by combining components like LEGO blocks!
-
- Click components to add them to your pipeline. Each component transforms the gradient in a specific way - stack them to create powerful optimization algorithms!
- """)
-
- # Hidden state for pipeline
- pipeline_state = gr.State(["gradient_input"])
-
- with gr.Row():
- # Left column - Component palette
- with gr.Column(scale=1):
- gr.Markdown("### 🎨 Component Palette")
- gr.Markdown("*Click blocks to add to pipeline*")
-
- # Component buttons organized by category
- for cat_id, category in OPTIMIZER_BLOCKS.items():
- if cat_id == "gradient": # Skip gradient input in palette
- continue
-
- gr.HTML(f'')
-
- with gr.Row():
- for comp in category["components"]:
- btn = gr.Button(
- value=f"{comp['icon']} {comp['name']}", elem_id=f"btn_{comp['id']}", size="sm"
- )
- # Store component ID in button for click handler
- btn.click(
- fn=lambda p, cid=comp["id"]: p + [cid],
- inputs=[pipeline_state],
- outputs=[pipeline_state],
- )
-
- # Recipe selector
- gr.Markdown("### 📚 Preset Recipes")
- recipe_dropdown = gr.Dropdown(choices=list(RECIPES.keys()), value=None, label="Load a preset optimizer")
- load_recipe_btn = gr.Button("Load Recipe", size="sm")
-
- # Center column - Main visualization
- with gr.Column(scale=2):
- # Pipeline builder
- gr.Markdown("### 🔧 Pipeline Builder")
- pipeline_display = gr.HTML()
-
- with gr.Row():
- clear_pipeline_btn = gr.Button("🗑️ Clear Pipeline", size="sm", variant="secondary")
- refresh_btn = gr.Button("🔄 Refresh Display", size="sm")
-
- # Visualization
- viz_plot = gr.Plot(label="")
-
- # Run button
- run_btn = gr.Button("🚀 Run Optimization", variant="primary", size="lg")
-
- # Right column - Parameters and metrics
- with gr.Column(scale=1):
- gr.Markdown("### ⚙️ Parameters")
-
- problem_select = gr.Dropdown(choices=list(PROBLEMS.keys()), value="Rosenbrock (2D)", label="Problem")
-
- lr_slider = gr.Slider(minimum=-4, maximum=-1, value=-2, step=0.1, label="Learning Rate (10^x)")
-
- steps_slider = gr.Slider(minimum=10, maximum=500, value=200, step=10, label="Steps")
-
- # Component-specific parameters
- with gr.Accordion("Advanced Parameters", open=False):
- beta_slider = gr.Slider(minimum=0.0, maximum=0.999, value=0.9, step=0.001, label="Momentum β")
- beta2_slider = gr.Slider(minimum=0.0, maximum=0.999, value=0.999, step=0.001, label="Adam β₂")
- eps_slider = gr.Slider(minimum=1e-8, maximum=1e-4, value=1e-8, step=1e-8, label="Epsilon")
-
- # Weight decay parameters
- weight_decay_slider = gr.Slider(
- minimum=0.0, maximum=0.1, value=0.01, step=0.001, label="Weight Decay"
- )
- ema_beta_slider = gr.Slider(
- minimum=0.9, maximum=0.999, value=0.999, step=0.001, label="EMA β (for weight decay)"
- )
-
- # Gradient clipping
- max_norm_slider = gr.Slider(
- minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Gradient Clip Norm"
- )
-
- # Preconditioning parameters
- shampoo_beta_slider = gr.Slider(
- minimum=0.9, maximum=0.999, value=0.99, step=0.001, label="Shampoo β"
- )
- precond_lr_slider = gr.Slider(
- minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Preconditioner LR"
- )
- precondition_frequency_slider = gr.Slider(
- minimum=1, maximum=100, value=10, step=1, label="Precondition Frequency"
- )
- rank_slider = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Low-rank Approximation Rank")
- param_count_slider = gr.Slider(
- minimum=1000, maximum=100000, value=10000, step=1000, label="Param Count (PSGD LRA)"
- )
-
- # Adaptive LR parameters
- initial_d_slider = gr.Slider(
- minimum=0.1, maximum=10.0, value=1.0, step=0.1, label="Initial D (D-adaptation)"
- )
- lr_lr_slider = gr.Slider(minimum=0.01, maximum=1.0, value=0.1, step=0.01, label="Learning Rate LR")
-
- # Special method parameters
- r_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Schedule-Free r")
- weight_lr_power_slider = gr.Slider(
- minimum=0.0, maximum=4.0, value=2.0, step=0.1, label="Weight LR Power"
- )
- sam_step_size_slider = gr.Slider(
- minimum=0.01, maximum=0.5, value=0.05, step=0.01, label="SAM Step Size"
- )
- beta2_scale_slider = gr.Slider(
- minimum=0.5, maximum=1.0, value=0.8, step=0.01, label="PALM β₂ Scale"
- )
-
- # Sign SGD parameters
- graft_checkbox = gr.Checkbox(value=True, label="Graft (Sign SGD)")
-
- gr.Markdown("### 📊 Metrics")
-
- final_loss_display = gr.Textbox(label="Final Loss", value="-")
- convergence_display = gr.Textbox(label="Steps to Converge", value="-")
-
- # Footer
- gr.Markdown("""
- ---
- ### 💡 How it works:
-
- 1. **Start with Gradient** - Every pipeline begins with raw gradients
- 2. **Add Transforms** - Click components to add them to your pipeline
- 3. **Order Matters** - Components are applied in sequence
- 4. **Run & Compare** - See how different combinations perform!
-
- **Example combinations:**
- - Adam = Gradient → Adam Scaling
- - AdamW = Gradient → Adam Scaling → Weight Decay
- - Momentum SGD = Gradient → Heavy Ball
- """)
-
- # Event handlers
- def update_display(pipeline):
- return create_pipeline_display(pipeline)
-
- def load_recipe(recipe_name):
- if recipe_name in RECIPES:
- return RECIPES[recipe_name].copy()
- return ["gradient_input"]
-
- def clear_pipeline():
- return ["gradient_input"]
-
- def remove_block(pipeline, index):
- """Remove a block from the pipeline"""
- if 0 <= index < len(pipeline) and pipeline[index] != "gradient_input":
- new_pipeline = pipeline.copy()
- new_pipeline.pop(index)
- return new_pipeline
- return pipeline
-
- def run_optimization_handler(
- problem,
- pipeline,
- steps,
- lr,
- beta,
- beta2,
- eps,
- weight_decay,
- ema_beta,
- max_norm,
- shampoo_beta,
- precond_lr,
- precondition_frequency,
- rank,
- param_count,
- initial_d,
- lr_lr,
- r,
- weight_lr_power,
- sam_step_size,
- beta2_scale,
- graft,
- ):
- """Run optimization with current pipeline"""
- if not pipeline or len(pipeline) == 1:
- pipeline = ["gradient_input"]
-
- # Run optimization with all parameters
- fig, metrics = run_optimization(
- problem,
- pipeline,
- steps,
- lr=lr,
- beta=beta,
- betas=(beta, beta2),
- beta2=beta2,
- eps=eps,
- weight_decay_to_ema=weight_decay,
- weight_decay_to_init=weight_decay,
- ema_beta=ema_beta,
- max_norm=max_norm,
- shampoo_beta=shampoo_beta,
- precond_lr=precond_lr,
- precondition_frequency=precondition_frequency,
- rank=rank,
- param_count=param_count,
- initial_d=initial_d,
- lr_lr=lr_lr,
- r=r,
- weight_lr_power=weight_lr_power,
- sam_step_size=sam_step_size,
- beta2_scale=beta2_scale,
- graft=graft,
- )
-
- return fig, f"{metrics['final_loss']:.2e}", str(metrics["steps_to_converge"])
-
- # Connect events
- # Update display when pipeline changes
- pipeline_state.change(fn=update_display, inputs=[pipeline_state], outputs=[pipeline_display])
-
- # Recipe loading
- load_recipe_btn.click(fn=load_recipe, inputs=[recipe_dropdown], outputs=[pipeline_state])
-
- # Clear pipeline
- clear_pipeline_btn.click(fn=clear_pipeline, outputs=[pipeline_state])
-
- # Refresh display
- refresh_btn.click(fn=update_display, inputs=[pipeline_state], outputs=[pipeline_display])
-
- # Run optimization
- run_btn.click(
- fn=run_optimization_handler,
- inputs=[
- problem_select,
- pipeline_state,
- steps_slider,
- lr_slider,
- beta_slider,
- beta2_slider,
- eps_slider,
- weight_decay_slider,
- ema_beta_slider,
- max_norm_slider,
- shampoo_beta_slider,
- precond_lr_slider,
- precondition_frequency_slider,
- rank_slider,
- param_count_slider,
- initial_d_slider,
- lr_lr_slider,
- r_slider,
- weight_lr_power_slider,
- sam_step_size_slider,
- beta2_scale_slider,
- graft_checkbox,
- ],
- outputs=[viz_plot, final_loss_display, convergence_display],
- )
-
- # Add JavaScript for removing blocks
- app.load(
- None,
- None,
- None,
- js="""
- function() {
- // Function to remove pipeline blocks
- window.removePipelineBlock = function(index) {
- // This would need to trigger a Gradio event
- console.log('Remove block at index:', index);
- // In a real implementation, this would update the pipeline state
- };
-
- console.log('Playground initialized');
- }
- """,
- )
-
- # Initialize display
- app.load(
- fn=lambda: (
- create_pipeline_display(["gradient_input"]),
- run_optimization("Rosenbrock (2D)", ["gradient_input", "adam_scale"], 200, lr=-2)[0],
- ),
- outputs=[pipeline_display, viz_plot],
- )
-
- return app
-
-
-if __name__ == "__main__":
- app = create_app()
- app.launch(share=False)
diff --git a/interactive/static/init-globals.js b/interactive/static/init-globals.js
deleted file mode 100644
index b820bc9..0000000
--- a/interactive/static/init-globals.js
+++ /dev/null
@@ -1,2 +0,0 @@
-// This file initializes global variables for the node editor
-// It will be populated by Python when the page loads
diff --git a/interactive/static/node-editor.js b/interactive/static/node-editor.js
deleted file mode 100644
index 525597f..0000000
--- a/interactive/static/node-editor.js
+++ /dev/null
@@ -1,570 +0,0 @@
-console.log('Node editor script starting execution...');
-
-class NodeEditor {
- constructor(containerId) {
- this.container = document.getElementById(containerId);
- this.canvas = document.getElementById('canvas');
- this.palette = document.getElementById('palette');
- this.inspector = document.getElementById('inspector');
- this.nodes = [];
- this.connections = [];
- this.selectedNode = null;
- this.draggingNode = null;
- this.connecting = null;
- this.nodeIdCounter = 0;
- this.offset = { x: 0, y: 0 };
-
- this.init();
- }
-
- init() {
- // Set up SVG for connections
- this.svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg');
- this.svg.style.position = 'absolute';
- this.svg.style.width = '100%';
- this.svg.style.height = '100%';
- this.svg.style.pointerEvents = 'none';
- this.svg.style.zIndex = '1';
- this.svg.style.top = '0';
- this.svg.style.left = '0';
- this.canvas.appendChild(this.svg);
-
- // Set up event listeners
- this.setupPalette();
- this.setupCanvas();
-
- // Add gradient input node by default
- this.addNode('gradient_input', 100, 300);
- }
-
- setupPalette() {
- // Make palette items draggable
- const items = this.palette.querySelectorAll('.palette-item');
- items.forEach(item => {
- item.draggable = true;
- item.addEventListener('dragstart', (e) => this.onPaletteDragStart(e));
- item.addEventListener('dragend', (e) => this.onPaletteDragEnd(e));
- });
- }
-
- setupCanvas() {
- this.canvas.addEventListener('dragover', (e) => e.preventDefault());
- this.canvas.addEventListener('drop', (e) => this.onCanvasDrop(e));
- this.canvas.addEventListener('mousedown', (e) => this.onCanvasMouseDown(e));
- this.canvas.addEventListener('mousemove', (e) => this.onCanvasMouseMove(e));
- this.canvas.addEventListener('mouseup', (e) => this.onCanvasMouseUp(e));
- }
-
- onPaletteDragStart(e) {
- e.target.classList.add('dragging');
- e.dataTransfer.effectAllowed = 'copy';
- e.dataTransfer.setData('nodeType', e.target.dataset.type);
- }
-
- onPaletteDragEnd(e) {
- e.target.classList.remove('dragging');
- }
-
- onCanvasDrop(e) {
- e.preventDefault();
- const nodeType = e.dataTransfer.getData('nodeType');
- if (nodeType) {
- const rect = this.canvas.getBoundingClientRect();
- const x = e.clientX - rect.left - 80;
- const y = e.clientY - rect.top - 30;
- this.addNode(nodeType, x, y);
- this.updateCodePreview();
- }
- }
-
- addNode(type, x, y) {
- console.log(`Adding node: type=${type}, x=${x}, y=${y}`);
- const nodeId = `node-${this.nodeIdCounter++}`;
- const nodeData = window.CHAINABLE_FUNCTIONS[type] || {
- description: type === 'gradient_input' ? 'Gradient Input' : 'Unknown',
- color: '#666',
- inputs: type === 'gradient_input' ? [] : ['grad'],
- outputs: type === 'gradient_input' ? ['grad'] : []
- };
-
- const node = document.createElement('div');
- node.className = 'node';
- node.id = nodeId;
- node.style.left = x + 'px';
- node.style.top = y + 'px';
- node.style.setProperty('--node-color', nodeData.color);
- node.style.borderLeftColor = nodeData.color;
- node.style.borderLeftWidth = '4px';
- console.log('Created node element:', node);
-
- // Special styling for gradient input
- if (type === 'gradient_input') {
- node.classList.add('gradient-input');
- }
-
- node.innerHTML = `
- ${type === 'gradient_input' ? 'Gradient Input' : type.replace(/_/g, ' ')}
- ${nodeData.description}
- ${nodeData.inputs.length ? '' : ''}
- ${nodeData.outputs.length ? '' : ''}
- `;
-
- console.log('Canvas element:', this.canvas);
- console.log('Canvas dimensions:', this.canvas.offsetWidth, 'x', this.canvas.offsetHeight);
- this.canvas.appendChild(node);
- console.log('Node appended to canvas. Total nodes in canvas:', this.canvas.querySelectorAll('.node').length);
-
- // Set up node interactions
- node.addEventListener('mousedown', (e) => this.onNodeMouseDown(e, nodeId));
-
- // Set up port interactions
- const inputPort = node.querySelector('.node-port-input');
- const outputPort = node.querySelector('.node-port-output');
-
- if (inputPort) {
- inputPort.addEventListener('mousedown', (e) => this.onPortMouseDown(e, nodeId, 'input'));
- }
- if (outputPort) {
- outputPort.addEventListener('mousedown', (e) => this.onPortMouseDown(e, nodeId, 'output'));
- }
-
- // Store node data
- this.nodes.push({
- id: nodeId,
- type: type,
- element: node,
- x: x,
- y: y,
- params: nodeData.params || {}
- });
-
- return nodeId;
- }
-
- onNodeMouseDown(e, nodeId) {
- if (e.target.classList.contains('node-port')) return;
-
- e.preventDefault();
- this.selectNode(nodeId);
-
- const node = this.nodes.find(n => n.id === nodeId);
- this.draggingNode = node;
-
- const rect = node.element.getBoundingClientRect();
- const canvasRect = this.canvas.getBoundingClientRect();
- this.offset = {
- x: e.clientX - rect.left + canvasRect.left,
- y: e.clientY - rect.top + canvasRect.top
- };
- }
-
- onPortMouseDown(e, nodeId, portType) {
- e.preventDefault();
- e.stopPropagation();
-
- const port = e.target;
- const rect = port.getBoundingClientRect();
- const canvasRect = this.canvas.getBoundingClientRect();
-
- this.connecting = {
- nodeId: nodeId,
- portType: portType,
- startX: rect.left + rect.width / 2 - canvasRect.left,
- startY: rect.top + rect.height / 2 - canvasRect.top
- };
-
- // Create preview connection
- const line = document.createElementNS('http://www.w3.org/2000/svg', 'path');
- line.classList.add('connection-preview');
- line.id = 'preview-connection';
- this.svg.appendChild(line);
- }
-
- onCanvasMouseMove(e) {
- const rect = this.canvas.getBoundingClientRect();
- const x = e.clientX - rect.left;
- const y = e.clientY - rect.top;
-
- if (this.draggingNode) {
- this.draggingNode.element.style.left = (x - this.offset.x) + 'px';
- this.draggingNode.element.style.top = (y - this.offset.y) + 'px';
- this.draggingNode.x = x - this.offset.x;
- this.draggingNode.y = y - this.offset.y;
- this.updateConnections();
- }
-
- if (this.connecting) {
- const preview = document.getElementById('preview-connection');
- if (preview) {
- const path = this.createConnectionPath(
- this.connecting.startX,
- this.connecting.startY,
- x,
- y
- );
- preview.setAttribute('d', path);
- }
- }
- }
-
- onCanvasMouseUp(e) {
- if (this.connecting) {
- const preview = document.getElementById('preview-connection');
- if (preview) preview.remove();
-
- // Check if we're over a port
- const target = document.elementFromPoint(e.clientX, e.clientY);
- if (target && target.classList.contains('node-port')) {
- const targetNode = target.closest('.node');
- const targetNodeId = targetNode.id;
- const targetPortType = target.classList.contains('node-port-input') ? 'input' : 'output';
-
- // Validate connection
- if (this.canConnect(this.connecting.nodeId, this.connecting.portType, targetNodeId, targetPortType)) {
- this.addConnection(this.connecting.nodeId, this.connecting.portType, targetNodeId, targetPortType);
- this.updateCodePreview();
- }
- }
-
- this.connecting = null;
- }
-
- this.draggingNode = null;
- }
-
- canConnect(fromNodeId, fromPortType, toNodeId, toPortType) {
- // Can't connect to same node
- if (fromNodeId === toNodeId) return false;
-
- // Must connect output to input
- if (fromPortType === toPortType) return false;
-
- // Check if connection already exists
- const exists = this.connections.some(c =>
- (c.from.nodeId === fromNodeId && c.to.nodeId === toNodeId) ||
- (c.from.nodeId === toNodeId && c.to.nodeId === fromNodeId)
- );
-
- return !exists;
- }
-
- addConnection(fromNodeId, fromPortType, toNodeId, toPortType) {
- // Ensure output connects to input
- if (fromPortType === 'input') {
- [fromNodeId, toNodeId] = [toNodeId, fromNodeId];
- [fromPortType, toPortType] = [toPortType, fromPortType];
- }
-
- const connection = {
- from: { nodeId: fromNodeId, portType: fromPortType },
- to: { nodeId: toNodeId, portType: toPortType },
- element: null
- };
-
- const line = document.createElementNS('http://www.w3.org/2000/svg', 'path');
- line.classList.add('connection');
- line.classList.add('animated');
- this.svg.appendChild(line);
-
- connection.element = line;
- this.connections.push(connection);
-
- this.updateConnections();
- }
-
- updateConnections() {
- this.connections.forEach(conn => {
- const fromNode = this.nodes.find(n => n.id === conn.from.nodeId);
- const toNode = this.nodes.find(n => n.id === conn.to.nodeId);
-
- if (fromNode && toNode) {
- const fromPort = fromNode.element.querySelector('.node-port-output');
- const toPort = toNode.element.querySelector('.node-port-input');
-
- if (fromPort && toPort) {
- const fromRect = fromPort.getBoundingClientRect();
- const toRect = toPort.getBoundingClientRect();
- const canvasRect = this.canvas.getBoundingClientRect();
-
- const x1 = fromRect.left + fromRect.width / 2 - canvasRect.left;
- const y1 = fromRect.top + fromRect.height / 2 - canvasRect.top;
- const x2 = toRect.left + toRect.width / 2 - canvasRect.left;
- const y2 = toRect.top + toRect.height / 2 - canvasRect.top;
-
- const path = this.createConnectionPath(x1, y1, x2, y2);
- conn.element.setAttribute('d', path);
-
- // Store path coordinates for particle animation
- conn.pathCoords = { x1, y1, x2, y2 };
- }
- }
- });
- }
-
- // Animate gradient flow particles
- startGradientFlow() {
- if (this.flowInterval) return;
-
- this.flowInterval = setInterval(() => {
- this.connections.forEach(conn => {
- if (conn.pathCoords) {
- this.createFlowParticle(conn.pathCoords);
- }
- });
- }, 500);
- }
-
- stopGradientFlow() {
- if (this.flowInterval) {
- clearInterval(this.flowInterval);
- this.flowInterval = null;
- }
-
- // Remove all particles
- const particles = this.canvas.querySelectorAll('.flow-particle');
- particles.forEach(p => p.remove());
- }
-
- createFlowParticle(coords) {
- const particle = document.createElement('div');
- particle.className = 'flow-particle';
- particle.style.left = coords.x1 + 'px';
- particle.style.top = coords.y1 + 'px';
- this.canvas.appendChild(particle);
-
- // Animate along bezier curve
- const duration = 2000;
- const startTime = Date.now();
-
- const animate = () => {
- const elapsed = Date.now() - startTime;
- const t = Math.min(elapsed / duration, 1);
-
- if (t >= 1) {
- particle.remove();
- return;
- }
-
- // Calculate position on bezier curve
- const dx = Math.abs(coords.x2 - coords.x1);
- const cp1x = coords.x1 + dx * 0.5;
- const cp2x = coords.x2 - dx * 0.5;
-
- // Bezier curve formula
- const x = Math.pow(1-t, 3) * coords.x1 +
- 3 * Math.pow(1-t, 2) * t * cp1x +
- 3 * (1-t) * Math.pow(t, 2) * cp2x +
- Math.pow(t, 3) * coords.x2;
-
- const y = Math.pow(1-t, 3) * coords.y1 +
- 3 * Math.pow(1-t, 2) * t * coords.y1 +
- 3 * (1-t) * Math.pow(t, 2) * coords.y2 +
- Math.pow(t, 3) * coords.y2;
-
- particle.style.left = x + 'px';
- particle.style.top = y + 'px';
- particle.style.opacity = Math.sin(t * Math.PI);
-
- requestAnimationFrame(animate);
- };
-
- requestAnimationFrame(animate);
- }
-
- createConnectionPath(x1, y1, x2, y2) {
- const dx = Math.abs(x2 - x1);
- const cp1x = x1 + dx * 0.5;
- const cp2x = x2 - dx * 0.5;
- return `M ${x1} ${y1} C ${cp1x} ${y1}, ${cp2x} ${y2}, ${x2} ${y2}`;
- }
-
- selectNode(nodeId) {
- // Deselect previous
- if (this.selectedNode) {
- this.selectedNode.element.classList.remove('selected');
- }
-
- // Select new
- const node = this.nodes.find(n => n.id === nodeId);
- if (node) {
- node.element.classList.add('selected');
- this.selectedNode = node;
- this.showInspector(node);
- }
- }
-
- showInspector(node) {
- this.inspector.classList.add('active');
-
- const nodeData = window.CHAINABLE_FUNCTIONS[node.type];
- if (!nodeData || !nodeData.params) return;
-
- // Build inspector content
- let html = `${node.type.replace(/_/g, ' ')}
`;
-
- Object.entries(nodeData.params).forEach(([key, defaultValue]) => {
- const value = node.params[key] || defaultValue;
- html += `
-
- `;
- });
-
- this.inspector.innerHTML = html;
-
- // Set up parameter change listeners
- this.inspector.querySelectorAll('.inspector-input').forEach(input => {
- input.addEventListener('change', (e) => {
- const param = e.target.dataset.param;
- let value = e.target.value;
-
- // Convert to appropriate type
- if (e.target.type === 'number') {
- value = parseFloat(value);
- }
-
- node.params[param] = value;
- this.updateCodePreview();
- });
- });
- }
-
- updateCodePreview() {
- const codePreview = document.getElementById('code-output');
- if (!codePreview) return;
-
- // Generate Python code from pipeline
- const pipelineData = this.exportPipeline();
- const nodes = pipelineData.nodes.filter(n => n.type !== 'gradient_input');
-
- if (nodes.length === 0) {
- codePreview.textContent = '# Add optimizer components to build your pipeline';
- return;
- }
-
- let code = 'optimizer = BaseOpt(\n';
- code += ' model.parameters(),\n';
- code += ' lr=0.001,\n';
-
- // Add relevant parameters
- const params = {};
- nodes.forEach(node => {
- Object.assign(params, node.params);
- });
-
- Object.entries(params).forEach(([key, value]) => {
- if (typeof value === 'string') {
- code += ` ${key}="${value}",\n`;
- } else {
- code += ` ${key}=${value},\n`;
- }
- });
-
- code += ' fns=[\n';
- nodes.forEach(node => {
- code += ` ${node.type},\n`;
- });
- code += ' ]\n';
- code += ')';
-
- codePreview.textContent = code;
-
- // Update hidden pipeline data
- const pipelineInput = document.getElementById('pipeline-data');
- if (pipelineInput) {
- pipelineInput.value = JSON.stringify(pipelineData);
- }
- }
-
- exportPipeline() {
- return {
- nodes: this.nodes.map(n => ({
- id: n.id,
- type: n.type,
- x: n.x,
- y: n.y,
- params: n.params
- })),
- connections: this.connections.map(c => ({
- from: c.from,
- to: c.to
- }))
- };
- }
-
- loadRecipe(recipe) {
- // Clear existing nodes except gradient input
- this.nodes = this.nodes.filter(n => n.type === 'gradient_input');
- this.connections = [];
- this.selectedNode = null;
-
- // Clear canvas except gradient input
- const nodesToRemove = this.canvas.querySelectorAll('.node:not([id="node-0"])');
- nodesToRemove.forEach(n => n.remove());
-
- // Clear connections
- this.svg.innerHTML = '';
-
- // Add nodes from recipe
- let lastNodeId = 'node-0'; // gradient input
- let x = 300;
- const y = 300;
-
- recipe.forEach((item, index) => {
- const nodeId = this.addNode(item.name, x, y);
- const node = this.nodes.find(n => n.id === nodeId);
-
- // Set parameters
- if (item.params) {
- node.params = item.params;
- }
-
- // Connect to previous node
- this.addConnection(lastNodeId, 'output', nodeId, 'input');
-
- lastNodeId = nodeId;
- x += 200;
- });
-
- this.updateCodePreview();
- }
-}
-
-// Initialize when DOM is ready
-function initializeNodeEditor() {
- try {
- console.log('Attempting to initialize NodeEditor...');
- const editorElement = document.getElementById('node-editor');
- if (!editorElement) {
- console.error('node-editor element not found! Retrying in 100ms...');
- setTimeout(initializeNodeEditor, 100);
- return;
- }
- console.log('Found node-editor element:', editorElement);
- console.log('Editor dimensions:', editorElement.offsetWidth, 'x', editorElement.offsetHeight);
-
- if (window.nodeEditor) {
- console.log('NodeEditor already initialized');
- return;
- }
-
- window.nodeEditor = new NodeEditor('node-editor');
- console.log('NodeEditor initialized successfully');
- } catch (error) {
- console.error('Error initializing NodeEditor:', error);
- }
-}
-
-// Try multiple initialization methods
-document.addEventListener('DOMContentLoaded', initializeNodeEditor);
-// Also try after a delay in case content loads after DOMContentLoaded
-setTimeout(initializeNodeEditor, 100);
-setTimeout(initializeNodeEditor, 500);
-setTimeout(initializeNodeEditor, 1000);
diff --git a/pyproject.toml b/pyproject.toml
index 36916e8..54c4d55 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "heavyball"
-description = "Compile-first PyTorch optimizer library — AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features"
+description = "Compile-first PyTorch optimizer library - AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features"
version = "2.3.2"
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
classifiers = ["Intended Audience :: Developers",
diff --git a/scripts/migrate_optimizer_state.py b/scripts/migrate_optimizer_state.py
index 8931cc7..0c0e3e0 100644
--- a/scripts/migrate_optimizer_state.py
+++ b/scripts/migrate_optimizer_state.py
@@ -1,18 +1,21 @@
#!/usr/bin/env python3
"""
-Utility to migrate HeavyBall 1.x optimizer state dicts to the 2.0.0 layout.
+Utility to migrate HeavyBall optimizer state dicts to the current (3.x) layout.
-The script rewrites per-parameter state keys to the new transform-indexed names,
-reshapes state storage so each parameter-view owns its own dictionary, and
-injects the HeavyBall-specific metadata block expected by 2.0.0 optimizers.
+Supports two migration paths:
+- 1.x → 3.x: rewrites per-parameter state keys, reshapes storage, fixes param groups and metadata
+- 2.x → 3.x: renames foreach→multi_tensor in param groups, strips stale metadata
+
+Version auto-detection:
+- 1.x: flat per-parameter state (not nested by view index)
+- 2.x: nested state with 'foreach' in param groups
+- 3.x: nested state with 'multi_tensor' in param groups (already current)
"""
from __future__ import annotations
import functools
import importlib
-import pickle
-import random
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
@@ -20,6 +23,42 @@
import torch
import typer
+_CLASS_RENAMES = {
+ "ForeachAdamW": "AdamW",
+ "ForeachNAdam": "NAdam",
+ "ForeachAdEMAMix": "AdEMAMix",
+ "ForeachAdamC": "AdamC",
+ "ForeachRMSprop": "RMSprop",
+ "ForeachSFAdamW": "SFAdamW",
+ "ForeachADOPT": "ADOPT",
+ "ForeachMuon": "Muon",
+ "ForeachLaProp": "LaProp",
+ "ForeachSOAP": "SOAP",
+ "ForeachSOAPNAdam": "SOAPNAdam",
+ "ForeachSOAPAdEMAMix": "SOAPAdEMAMix",
+ "ForeachSignLaProp": "SignLaProp",
+ "ForeachSOLP": "SOLP",
+ "ForeachPSGDKron": "PSGDKron",
+ "ForeachPSGDLRA": "PSGDLRA",
+ # Deleted subclasses → parent (all were __init__-only default overrides)
+ "PaLMForeachSFAdamW": "SFAdamW",
+ "PaLMForeachSOAP": "SOAP",
+ "PrecondScheduleForeachSOAP": "SOAP",
+ "PrecondSchedulePaLMForeachSOAP": "SOAP",
+ "ForeachPurePSGD": "PSGDKron",
+ "ForeachCachedPSGDKron": "PSGDKron",
+ "ForeachCachedDelayedPSGDKron": "PSGDKron",
+ "ForeachDelayedPSGD": "PSGDKron",
+ "ForeachCachedNewtonPSGD": "PSGDKron",
+ "NewtonHybrid2PSGDKron": "PSGDKron",
+ "ForeachDelayedPSGDLRA": "PSGDLRA",
+ "ForeachNewtonPSGDLRA": "PSGDLRA",
+ "NewtonHybrid2PSGDLRA": "PSGDLRA",
+}
+
+_REMOVED_GROUP_KEYS = frozenset({"stochastic_schedule"})
+_REMOVED_META_KEYS = frozenset({"stochastic_schedule", "precond_rng"})
+
@dataclass
class TransformMapping:
@@ -28,16 +67,37 @@ class TransformMapping:
transform_idx: int
+def _resolve_class_name(class_name: str) -> str:
+ return _CLASS_RENAMES.get(class_name, class_name)
+
+
def _load_optimizer_class(qualified_name: str):
if "." in qualified_name:
module_name, class_name = qualified_name.rsplit(".", 1)
else:
module_name, class_name = "heavyball", qualified_name
+ class_name = _resolve_class_name(class_name)
module = importlib.import_module(module_name)
try:
return getattr(module, class_name)
except AttributeError as exc:
- raise ValueError(f"Optimizer class '{qualified_name}' not found") from exc
+ raise ValueError(f"Optimizer class '{module_name}.{class_name}' not found") from exc
+
+
+def _detect_version(state_dict: Dict[str, Any]) -> int:
+ nested = False
+ for entry in state_dict.get("state", {}).values():
+ if isinstance(entry, dict):
+ nested = any(isinstance(v, dict) for v in entry.values())
+ if not nested:
+ return 1
+ break
+
+ groups = state_dict.get("param_groups", [])
+ has_multi_tensor = any(isinstance(g, dict) and "multi_tensor" in g for g in groups)
+ if nested:
+ return 3 if has_multi_tensor else 2
+ return 3 if has_multi_tensor else 1
def _guess_tensor_meta(state_entry: Dict[str, Any]) -> Tuple[Tuple[int, ...], torch.dtype]:
@@ -69,8 +129,9 @@ def _build_dummy_parameters(state: Dict[int, Dict[str, Any]], param_groups: Sequ
def _normalise_group_options(group: Dict[str, Any]) -> Dict[str, Any]:
options: Dict[str, Any] = {}
for key, value in group.items():
- if key == "params":
+ if key == "params" or key in _REMOVED_GROUP_KEYS:
continue
+ key = "multi_tensor" if key == "foreach" else key
if isinstance(value, list) and key in {"betas", "weight_decay_steps"}:
options[key] = tuple(value)
else:
@@ -107,6 +168,12 @@ def walk(queue: Iterable[Any]):
elif isinstance(current, C.Parallel):
for branch in current.branches:
stack.extend(branch)
+ elif isinstance(current, C.Route):
+ for _, fns in current.routes:
+ if fns:
+ stack.extend(fns)
+ if current.default:
+ stack.extend(current.default)
elif isinstance(current, (list, tuple)):
stack.extend(current)
@@ -180,7 +247,24 @@ def _migrate_single_state(entry: Dict[str, Any], mappings: List[TransformMapping
return migrated
-def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]:
+def _migrate_v2_to_v3(state_dict: Dict[str, Any]) -> None:
+ for group in state_dict.get("param_groups", []):
+ if not isinstance(group, dict):
+ continue
+ if "foreach" in group:
+ group["multi_tensor"] = group.pop("foreach")
+ for key in _REMOVED_GROUP_KEYS:
+ group.pop(key, None)
+
+ hb = state_dict.get("heavyball", {})
+ for key in _REMOVED_META_KEYS:
+ hb.pop(key, None)
+ inner = hb.get("inner_group", {})
+ for key in _REMOVED_META_KEYS:
+ inner.pop(key, None)
+
+
+def _migrate_v1_state(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]:
opt_cls = _load_optimizer_class(optimizer_class)
optimizer, _ = _instantiate_optimizer(opt_cls, old_state)
template = optimizer.state_dict()
@@ -199,11 +283,19 @@ def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[
heavyball_meta = migrated.setdefault("heavyball", template.get("heavyball", {}))
if "inner_group" not in heavyball_meta:
- heavyball_meta["inner_group"] = {"stochastic_schedule": None}
- if "stochastic_schedule" not in heavyball_meta:
- heavyball_meta["stochastic_schedule"] = None
- if "precond_rng" not in heavyball_meta:
- heavyball_meta["precond_rng"] = pickle.dumps(random.Random(0x12312))
+ heavyball_meta["inner_group"] = {}
+ return migrated
+
+
+def migrate_state_dict(old_state: Dict[str, Any], optimizer_class: str) -> Dict[str, Any]:
+ version = _detect_version(old_state)
+ if version == 1:
+ migrated = _migrate_v1_state(old_state, optimizer_class)
+ elif version == 2:
+ migrated = dict(old_state)
+ else:
+ return dict(old_state)
+ _migrate_v2_to_v3(migrated)
return migrated
@@ -221,7 +313,7 @@ def _resolve_state_container(root: Dict[str, Any], key_path: Sequence[str]) -> D
app = typer.Typer(help="Utilities for migrating HeavyBall optimizer checkpoints.")
-@app.command(help="Migrate a HeavyBall optimizer state dict to the 2.0.0 layout.")
+@app.command(help="Migrate a HeavyBall optimizer state dict to the current (3.x) layout.")
def migrate(
checkpoint: Path = typer.Argument(
...,
@@ -234,7 +326,7 @@ def migrate(
),
optimizer_class: str = typer.Argument(
...,
- help="Optimizer class to instantiate (e.g., heavyball.AdamW)",
+ help="Optimizer class name (e.g., heavyball.AdamW). Old names like ForeachAdamW are resolved automatically.",
),
state_key: str = typer.Option(
"optimizer",
diff --git a/test/test_compile_step.py b/test/test_compile_step.py
index de210e7..eb7bbaf 100644
--- a/test/test_compile_step.py
+++ b/test/test_compile_step.py
@@ -4,6 +4,7 @@
import torch
import heavyball
+from heavyball.chainable import ChainOpt, WarmupGuard, _walk_fns
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -81,3 +82,33 @@ def test_compile_step_matches_eager(opt_name, opt_cls):
for p_ref, p_test in zip(model_ref.parameters(), model_test.parameters()):
diff = (p_ref.data - p_test.data).abs().max().item()
assert diff < 1e-4, f"compile_step diverged: max_diff={diff}"
+
+
+def _max_warmup(opt):
+ return max((len(ft.warmup_fns) for ft in _walk_fns(opt.fns) if isinstance(ft, WarmupGuard)), default=0)
+
+
+@pytest.mark.parametrize("opt_name,opt_cls", _optimizer_params())
+def test_needs_init_clears(opt_name, opt_cls):
+ """_needs_init must become False after max_warmup + 1 steps for all ChainOpt optimizers.
+
+ Catches bugs where Route-based or warmup_guard-based optimizers permanently
+ force eager mode because different params accumulate different is_initialized
+ sets that never individually cover _transform_ids.
+ """
+ if not issubclass(opt_cls, ChainOpt):
+ pytest.skip("not a ChainOpt")
+
+ kwargs = dict(EXTRA_KWARGS.get(opt_name, {}))
+ model = _make_model()
+ opt = opt_cls(model.parameters(), **kwargs)
+ n = _max_warmup(opt) + 1
+
+ _run_steps(model, opt, n=n)
+
+ for group in opt.param_groups:
+ state = [opt.state_(p) for p in group["params"]]
+ assert not opt._needs_init(state), (
+ f"{opt_name}: _needs_init stuck True after {n} steps | "
+ f"compile_step will never engage"
+ )
diff --git a/test/test_ecc.py b/test/test_ecc.py
index 163d4f9..45c4758 100644
--- a/test/test_ecc.py
+++ b/test/test_ecc.py
@@ -362,7 +362,7 @@ def test_ecc_live_path_nonzero_correction():
p = list(m.parameters())[0]
st, ecc_keys = _ecc_keys(o, p)
for ek in ecc_keys:
- assert st[ek].any(), f"ECC correction '{ek}' is all zeros — stochastic_round_ likely mutating source"
+ assert st[ek].any(), f"ECC correction '{ek}' is all zeros - stochastic_round_ likely mutating source"
del m, o
clean()
diff --git a/test/test_migrate_cli.py b/test/test_migrate_cli.py
index 253c6d8..d3f3de7 100644
--- a/test/test_migrate_cli.py
+++ b/test/test_migrate_cli.py
@@ -1,5 +1,7 @@
import importlib.util
import pathlib
+import pickle
+import random
import sys
import pytest
@@ -14,164 +16,711 @@
sys.modules[MODULE_NAME] = migrate_script
SPEC.loader.exec_module(migrate_script) # type: ignore[arg-type]
+# ---------------------------------------------------------------------------
+# Rename tables (exhaustive, mirrors _CLASS_RENAMES in the script)
+# ---------------------------------------------------------------------------
+
+_DIRECT_RENAMES = [
+ ("ForeachAdamW", "AdamW"),
+ ("ForeachNAdam", "NAdam"),
+ ("ForeachAdEMAMix", "AdEMAMix"),
+ ("ForeachAdamC", "AdamC"),
+ ("ForeachRMSprop", "RMSprop"),
+ ("ForeachSFAdamW", "SFAdamW"),
+ ("ForeachADOPT", "ADOPT"),
+ ("ForeachMuon", "Muon"),
+ ("ForeachLaProp", "LaProp"),
+ ("ForeachSOAP", "SOAP"),
+ ("ForeachSOAPNAdam", "SOAPNAdam"),
+ ("ForeachSOAPAdEMAMix", "SOAPAdEMAMix"),
+ ("ForeachSignLaProp", "SignLaProp"),
+ ("ForeachSOLP", "SOLP"),
+ ("ForeachPSGDKron", "PSGDKron"),
+ ("ForeachPSGDLRA", "PSGDLRA"),
+]
+
+_DELETED_RENAMES = [
+ ("PaLMForeachSFAdamW", "SFAdamW"),
+ ("PaLMForeachSOAP", "SOAP"),
+ ("PrecondScheduleForeachSOAP", "SOAP"),
+ ("PrecondSchedulePaLMForeachSOAP", "SOAP"),
+ ("ForeachPurePSGD", "PSGDKron"),
+ ("ForeachCachedPSGDKron", "PSGDKron"),
+ ("ForeachCachedDelayedPSGDKron", "PSGDKron"),
+ ("ForeachDelayedPSGD", "PSGDKron"),
+ ("ForeachCachedNewtonPSGD", "PSGDKron"),
+ ("NewtonHybrid2PSGDKron", "PSGDKron"),
+ ("ForeachDelayedPSGDLRA", "PSGDLRA"),
+ ("ForeachNewtonPSGDLRA", "PSGDLRA"),
+ ("NewtonHybrid2PSGDLRA", "PSGDLRA"),
+]
+
+_ALL_RENAMES = _DIRECT_RENAMES + _DELETED_RENAMES
+
+# PSGDLRA uses a different constructor signature (beta not betas, rank param, etc.)
+# so e2e tests using AdamW-style param_groups must exclude it
+_LRA_TARGETS = {"PSGDLRA"}
+_DIRECT_RENAMES_NO_LRA = [(o, n) for o, n in _DIRECT_RENAMES if n not in _LRA_TARGETS]
+_DELETED_RENAMES_NO_LRA = [(o, n) for o, n in _DELETED_RENAMES if n not in _LRA_TARGETS]
+_DIRECT_RENAMES_LRA = [(o, n) for o, n in _DIRECT_RENAMES if n in _LRA_TARGETS]
+_DELETED_RENAMES_LRA = [(o, n) for o, n in _DELETED_RENAMES if n in _LRA_TARGETS]
+
+_PASSTHROUGH = ["AdamW", "NAdam", "PSGDKron", "SGD", "SOAP", "Muon", "ADOPT", "LaProp", "PSGDLRA", "SFAdamW"]
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _flat_state(*shapes):
+ return {
+ i: {
+ "update_by_adam_exp_avg": torch.ones(s),
+ "update_by_adam_exp_avg_sq": torch.full(s, 2.0),
+ "is_initialized": [0],
+ }
+ for i, s in enumerate(shapes)
+ }
+
+
+def _nested_state(*shapes):
+ return {i: {0: {"key_0": torch.zeros(s), "is_initialized": [0]}} for i, s in enumerate(shapes)}
+
+
+def _v1_group(pids):
+ return {
+ "params": pids,
+ "lr": 0.0025,
+ "betas": [0.9, 0.99],
+ "eps": 1e-8,
+ "weight_decay": 0.0,
+ "warmup_steps": 0,
+ "foreach": True,
+ "stochastic_schedule": False,
+ "storage_dtype": "float32",
+ "mars": False,
+ "caution": False,
+ "mars_gamma": 0.0025,
+ "gradient_clipping": "use_default",
+ "update_clipping": "use_default",
+ "palm": "use_default",
+ "beta2_scale": 0.8,
+ }
+
+
+def _v2_group(pids):
+ return {**_v1_group(pids), "__class__": "heavyball.ForeachAdamW"}
+
+
+def _v3_group(pids):
+ g = _v1_group(pids)
+ g.pop("foreach")
+ g.pop("stochastic_schedule")
+ g["multi_tensor"] = True
+ return g
+
+
+def _v2_meta():
+ return {
+ "inner_group": {"stochastic_schedule": None},
+ "stochastic_schedule": None,
+ "precond_rng": pickle.dumps(random.Random(0x12312)),
+ "use_ema": False,
+ }
+
+
+def _v3_meta():
+ return {"inner_group": {}, "use_ema": False}
+
+
+def _load_heavyball_fresh():
+ package_root = pathlib.Path(__file__).resolve().parents[1]
+ heavyball_pkg = package_root / "heavyball"
+ saved = {n: sys.modules[n] for n in list(sys.modules) if n == "heavyball" or n.startswith("heavyball.")}
+ for n in list(sys.modules):
+ if n == "heavyball" or n.startswith("heavyball."):
+ sys.modules.pop(n)
+ spec = importlib.util.spec_from_file_location(
+ "heavyball", heavyball_pkg / "__init__.py", submodule_search_locations=[str(heavyball_pkg)]
+ )
+ mod = importlib.util.module_from_spec(spec)
+ sys.modules["heavyball"] = mod
+ spec.loader.exec_module(mod)
+ return saved
+
+
+def _restore_heavyball(saved):
+ for n in list(sys.modules):
+ if n == "heavyball" or n.startswith("heavyball."):
+ sys.modules.pop(n)
+ sys.modules.update(saved)
+
@pytest.fixture()
def runner():
return CliRunner()
-def test_cli_dry_run_updates_state(monkeypatch, runner, tmp_path):
- checkpoint_path = tmp_path / "ckpt.pt"
- checkpoint_path.touch()
+# ====================================================================
+# _resolve_class_name
+# ====================================================================
- state_container = {"state": {"initial": True}, "param_groups": ["group"]}
- checkpoint = {"optimizer": state_container}
- def fake_load(path, map_location=None):
- return checkpoint
+@pytest.mark.parametrize("old,new", _ALL_RENAMES)
+def test_resolve_class_name_renames(old, new):
+ assert migrate_script._resolve_class_name(old) == new
- def fake_migrate(state, _):
- return {"state": {"migrated": True}, "param_groups": []}
- def fail_save(*args, **kwargs):
- pytest.fail("torch.save should not run during dry-run")
+@pytest.mark.parametrize("name", _PASSTHROUGH)
+def test_resolve_class_name_passthrough(name):
+ assert migrate_script._resolve_class_name(name) == name
- monkeypatch.setattr(migrate_script.torch, "load", fake_load)
- monkeypatch.setattr(migrate_script, "migrate_state_dict", fake_migrate)
- monkeypatch.setattr(migrate_script.torch, "save", fail_save)
- result = runner.invoke(
- migrate_script.app,
- [str(checkpoint_path), "heavyball.Mock", "--state-key", "optimizer", "--dry-run"],
- )
+# ====================================================================
+# _load_optimizer_class
+# ====================================================================
+
+
+@pytest.mark.parametrize("old,new", _ALL_RENAMES)
+def test_load_optimizer_class_old_names(old, new):
+ import heavyball
+
+ assert migrate_script._load_optimizer_class(f"heavyball.{old}") is getattr(heavyball, new)
+
+
+@pytest.mark.parametrize("name", _PASSTHROUGH)
+def test_load_optimizer_class_current_names(name):
+ import heavyball
+
+ assert migrate_script._load_optimizer_class(name) is getattr(heavyball, name)
+
+
+def test_load_optimizer_class_bare_name():
+ import heavyball
+
+ assert migrate_script._load_optimizer_class("AdamW") is heavyball.AdamW
+
+
+def test_load_optimizer_class_invalid_raises():
+ with pytest.raises(ValueError, match="not found"):
+ migrate_script._load_optimizer_class("heavyball.TotallyFakeOptimizer")
+
+
+# ====================================================================
+# _detect_version
+# ====================================================================
+
+_DETECT_CASES = [
+ ("flat_foreach", {"state": _flat_state((2,)), "param_groups": [{"foreach": True}]}, 1),
+ ("flat_multi_tensor", {"state": _flat_state((2,)), "param_groups": [{"multi_tensor": True}]}, 1),
+ ("flat_neither", {"state": _flat_state((2,)), "param_groups": [{}]}, 1),
+ ("flat_multi_group", {"state": _flat_state((2,), (3,)), "param_groups": [{"foreach": True}, {}]}, 1),
+ ("nested_foreach", {"state": _nested_state((2,)), "param_groups": [{"foreach": True}]}, 2),
+ ("nested_foreach_multi_group", {"state": _nested_state((2,), (3,)), "param_groups": [{"foreach": True}, {"foreach": True}]}, 2),
+ ("nested_multi_tensor", {"state": _nested_state((2,)), "param_groups": [{"multi_tensor": True}]}, 3),
+ ("nested_multi_tensor_multi_group", {"state": _nested_state((2,), (3,)), "param_groups": [{"multi_tensor": True}, {}]}, 3),
+ ("empty_state_multi_tensor", {"state": {}, "param_groups": [{"multi_tensor": True}]}, 3),
+ ("empty_state_foreach", {"state": {}, "param_groups": [{"foreach": True}]}, 1),
+ ("empty_state_bare", {"state": {}, "param_groups": [{}]}, 1),
+ ("empty_state_empty_groups", {"state": {}, "param_groups": []}, 1),
+]
+
+
+@pytest.mark.parametrize("name,sd,expected", _DETECT_CASES, ids=[c[0] for c in _DETECT_CASES])
+def test_detect_version(name, sd, expected):
+ assert migrate_script._detect_version(sd) == expected
+
+
+# ====================================================================
+# _normalise_group_options
+# ====================================================================
+
+_NORM_CASES = [
+ ("foreach_renamed", {"foreach": True, "lr": 0.01}, {"multi_tensor": True, "lr": 0.01}),
+ ("foreach_false", {"foreach": False, "lr": 0.01}, {"multi_tensor": False, "lr": 0.01}),
+ ("multi_tensor_passthrough", {"multi_tensor": False, "lr": 0.01}, {"multi_tensor": False, "lr": 0.01}),
+ ("stochastic_stripped", {"foreach": True, "stochastic_schedule": False}, {"multi_tensor": True}),
+ ("betas_tuple", {"betas": [0.9, 0.999]}, {"betas": (0.9, 0.999)}),
+ ("weight_decay_steps_tuple", {"weight_decay_steps": [100, 200]}, {"weight_decay_steps": (100, 200)}),
+ ("params_stripped", {"params": [0, 1], "lr": 0.01}, {"lr": 0.01}),
+ ("empty_group", {}, {}),
+ ("params_only", {"params": [0]}, {}),
+ ("all_removed", {"params": [0], "stochastic_schedule": None}, {}),
+ ("kitchen_sink", {"params": [0, 1], "foreach": True, "stochastic_schedule": True, "lr": 0.01, "betas": [0.9, 0.99]}, {"multi_tensor": True, "lr": 0.01, "betas": (0.9, 0.99)}),
+]
+
+
+@pytest.mark.parametrize("name,group,expected", _NORM_CASES, ids=[c[0] for c in _NORM_CASES])
+def test_normalise_group_options(name, group, expected):
+ assert migrate_script._normalise_group_options(group) == expected
+
+
+# ====================================================================
+# _ensure_set
+# ====================================================================
+
+
+@pytest.mark.parametrize(
+ "inp,expected",
+ [
+ ({1, 2}, {1, 2}),
+ ([3, 4], {3, 4}),
+ ((5,), {5}),
+ (None, set()),
+ (7, {7}),
+ ([], set()),
+ ],
+ ids=["set", "list", "tuple", "none", "scalar", "empty_list"],
+)
+def test_ensure_set(inp, expected):
+ assert migrate_script._ensure_set(inp) == expected
+
+
+# ====================================================================
+# _guess_tensor_meta
+# ====================================================================
+
+
+@pytest.mark.parametrize(
+ "entry,shape,dtype",
+ [
+ ({"a": torch.zeros(3, 4, dtype=torch.float16)}, (3, 4), torch.float16),
+ ({"a": [torch.zeros(2, dtype=torch.bfloat16)]}, (2,), torch.bfloat16),
+ ({"a": "not_a_tensor", "b": torch.ones(5)}, (5,), torch.float32),
+ ({}, (1,), torch.float32),
+ ({"a": 42}, (1,), torch.float32),
+ ],
+ ids=["tensor", "tensor_list", "mixed_finds_tensor", "empty", "no_tensors"],
+)
+def test_guess_tensor_meta(entry, shape, dtype):
+ s, d = migrate_script._guess_tensor_meta(entry)
+ assert s == shape
+ assert d == dtype
+
+
+# ====================================================================
+# _resolve_state_container
+# ====================================================================
+
+
+def test_resolve_state_container_single_key():
+ root = {"optimizer": {"state": {}, "param_groups": []}}
+ assert migrate_script._resolve_state_container(root, ["optimizer"]) is root["optimizer"]
+
+
+def test_resolve_state_container_nested_key():
+ inner = {"state": {}, "param_groups": []}
+ root = {"model": {"training": {"opt": inner}}}
+ assert migrate_script._resolve_state_container(root, ["model", "training", "opt"]) is inner
+
+
+def test_resolve_state_container_missing_key():
+ with pytest.raises(KeyError, match="not found"):
+ migrate_script._resolve_state_container({"a": {}}, ["a", "b"])
+
+
+def test_resolve_state_container_not_optimizer():
+ with pytest.raises(ValueError, match="not an optimizer state dict"):
+ migrate_script._resolve_state_container({"opt": {"just_data": 1}}, ["opt"])
+
+
+# ====================================================================
+# _migrate_v2_to_v3
+# ====================================================================
+
+
+@pytest.mark.parametrize(
+ "meta_keys",
+ [
+ {"stochastic_schedule": None, "precond_rng": pickle.dumps(random.Random(0))},
+ {"stochastic_schedule": None},
+ {"precond_rng": b"rng"},
+ {},
+ ],
+ ids=["both", "stochastic_only", "precond_rng_only", "clean"],
+)
+def test_migrate_v2_to_v3_strips_stale_meta(meta_keys):
+ sd = {
+ "state": _nested_state((3,)),
+ "param_groups": [{"params": [0], "foreach": True, "stochastic_schedule": False}],
+ "heavyball": {"inner_group": {"stochastic_schedule": None}, "use_ema": False, **meta_keys},
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ group = sd["param_groups"][0]
+ assert group["multi_tensor"] is True
+ assert "foreach" not in group
+ assert "stochastic_schedule" not in group
+ hb = sd["heavyball"]
+ for key in ("stochastic_schedule", "precond_rng"):
+ assert key not in hb
+ assert key not in hb["inner_group"]
+ assert hb["use_ema"] is False
+
+
+def test_migrate_v2_to_v3_multi_group():
+ sd = {
+ "state": _nested_state((2,), (3,)),
+ "param_groups": [
+ {"params": [0], "foreach": True, "stochastic_schedule": False},
+ {"params": [1], "foreach": False, "stochastic_schedule": True, "lr": 0.1},
+ ],
+ "heavyball": _v2_meta(),
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ assert sd["param_groups"][0]["multi_tensor"] is True
+ assert sd["param_groups"][1]["multi_tensor"] is False
+ assert sd["param_groups"][1]["lr"] == 0.1
+ for g in sd["param_groups"]:
+ assert "foreach" not in g
+ assert "stochastic_schedule" not in g
+
+
+def test_migrate_v2_to_v3_preserves_non_stale_meta():
+ sd = {
+ "state": _nested_state((2,)),
+ "param_groups": [{"params": [0], "foreach": True}],
+ "heavyball": {**_v2_meta(), "use_ema": True, "ema_decay": 0.005, "compile_step": True},
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ assert sd["heavyball"]["use_ema"] is True
+ assert sd["heavyball"]["ema_decay"] == 0.005
+ assert sd["heavyball"]["compile_step"] is True
+
+
+def test_migrate_v2_to_v3_no_heavyball_key():
+ sd = {
+ "state": _nested_state((2,)),
+ "param_groups": [{"params": [0], "foreach": True}],
+ }
+ migrate_script._migrate_v2_to_v3(sd)
+ assert sd["param_groups"][0]["multi_tensor"] is True
+
+
+# ====================================================================
+# _migrate_single_state
+# ====================================================================
+
+
+@pytest.mark.parametrize("n_views", [1, 2, 5])
+def test_migrate_single_state_rewrites_keys(n_views):
+ mappings = [
+ migrate_script.TransformMapping("exp_avg", "exp_avg_0", 0),
+ migrate_script.TransformMapping("exp_avg_sq", "exp_avg_sq_0", 0),
+ ]
+ entry = {
+ "exp_avg": [torch.ones(2)] * n_views if n_views > 1 else torch.ones(2),
+ "exp_avg_sq": [torch.full((2,), 2.0)] * n_views if n_views > 1 else torch.full((2,), 2.0),
+ "is_initialized": [0],
+ }
+ migrated = migrate_script._migrate_single_state(entry, mappings)
+ for view_idx in range(n_views):
+ bucket = migrated[view_idx]
+ assert "exp_avg_0" in bucket
+ assert "exp_avg_sq_0" in bucket
+ assert "exp_avg" not in bucket
+ assert "exp_avg_sq" not in bucket
+ assert 0 in bucket["is_initialized"]
+
+
+def test_migrate_single_state_empty_mappings():
+ entry = {"some_key": torch.ones(3), "is_initialized": {0, 1}}
+ migrated = migrate_script._migrate_single_state(entry, [])
+ assert "some_key" in migrated[0]
+ assert set(migrated[0]["is_initialized"]) == {0, 1}
+
+
+def test_migrate_single_state_multiple_transforms():
+ mappings = [
+ migrate_script.TransformMapping("exp_avg", "update_by_adam_exp_avg_0", 0),
+ migrate_script.TransformMapping("exp_avg_sq", "update_by_adam_exp_avg_sq_0", 0),
+ migrate_script.TransformMapping("momentum", "heavyball_momentum_momentum_1", 1),
+ ]
+ entry = {
+ "exp_avg": torch.ones(4),
+ "exp_avg_sq": torch.full((4,), 2.0),
+ "momentum": torch.full((4,), 0.5),
+ "is_initialized": [0, 1],
+ }
+ migrated = migrate_script._migrate_single_state(entry, mappings)
+ bucket = migrated[0]
+ assert "update_by_adam_exp_avg_0" in bucket
+ assert "update_by_adam_exp_avg_sq_0" in bucket
+ assert "heavyball_momentum_momentum_1" in bucket
+ assert set(bucket["is_initialized"]) == {0, 1}
+
+
+def test_migrate_single_state_preserves_values():
+ t = torch.arange(6, dtype=torch.float32)
+ mappings = [migrate_script.TransformMapping("old", "new_0", 0)]
+ migrated = migrate_script._migrate_single_state({"old": t, "is_initialized": [0]}, mappings)
+ assert torch.equal(migrated[0]["new_0"], t)
+
+
+# ====================================================================
+# Full migrate_state_dict
+# ====================================================================
+
+
+def test_migrate_v2_state_dict():
+ sd = {
+ "state": _nested_state((4,)),
+ "param_groups": [_v2_group([0])],
+ "heavyball": _v2_meta(),
+ }
+ migrated = migrate_script.migrate_state_dict(sd, "heavyball.ForeachAdamW")
+ group = migrated["param_groups"][0]
+ assert group["multi_tensor"] is True
+ assert "foreach" not in group
+ assert "stochastic_schedule" not in group
+ hb = migrated.get("heavyball", {})
+ assert "stochastic_schedule" not in hb
+ assert "precond_rng" not in hb
+
+
+def test_migrate_v3_is_noop():
+ sd = {
+ "state": _nested_state((4,)),
+ "param_groups": [_v3_group([0])],
+ "heavyball": _v3_meta(),
+ }
+ original_key0 = sd["state"][0][0]["key_0"].clone()
+ migrated = migrate_script.migrate_state_dict(sd, "heavyball.AdamW")
+ assert migrated is not sd
+ assert torch.equal(migrated["state"][0][0]["key_0"], original_key0)
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+
+
+def test_migrate_v2_with_old_class_name():
+ sd = {
+ "state": _nested_state((4,)),
+ "param_groups": [_v2_group([0])],
+ "heavyball": _v2_meta(),
+ }
+ migrated = migrate_script.migrate_state_dict(sd, "ForeachAdamW")
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+
+
+def test_migrate_v2_multi_param():
+ sd = {
+ "state": _nested_state((4, 4), (8,), (2, 2, 2)),
+ "param_groups": [_v2_group([0, 1, 2])],
+ "heavyball": _v2_meta(),
+ }
+ migrated = migrate_script.migrate_state_dict(sd, "AdamW")
+ for pid in (0, 1, 2):
+ assert pid in migrated["state"]
+ assert 0 in migrated["state"][pid]
+
+
+# ====================================================================
+# CLI (typer) | mocked
+# ====================================================================
+
+def test_cli_dry_run(monkeypatch, runner, tmp_path):
+ ckpt = tmp_path / "ckpt.pt"
+ ckpt.touch()
+ container = {"state": {}, "param_groups": []}
+ checkpoint = {"optimizer": container}
+
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: checkpoint)
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {"ok": True}, "param_groups": []})
+ monkeypatch.setattr(migrate_script.torch, "save", lambda *a, **kw: pytest.fail("save during dry-run"))
+
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--dry-run"])
assert result.exit_code == 0
assert "Dry run" in result.stdout
- assert state_container == {"state": {"migrated": True}, "param_groups": []}
+ assert container["state"] == {"ok": True}
def test_cli_writes_output(monkeypatch, runner, tmp_path):
- checkpoint_path = tmp_path / "source.pt"
- checkpoint_path.touch()
- output_path = tmp_path / "out.pt"
-
- state_container = {"state": {"initial": True}, "param_groups": ["group"]}
- checkpoint = {"optimizer": state_container}
- migrated = {"state": {"migrated": True}, "param_groups": []}
+ ckpt = tmp_path / "source.pt"
+ ckpt.touch()
+ out = tmp_path / "out.pt"
+ migrated = {"state": {"done": True}, "param_groups": []}
+ checkpoint = {"optimizer": {"state": {}, "param_groups": []}}
+ saved = {}
- def fake_load(path, map_location=None):
- return checkpoint
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: checkpoint)
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: migrated)
+ monkeypatch.setattr(migrate_script.torch, "save", lambda obj, path: saved.update(obj=obj, path=pathlib.Path(path)))
- def fake_migrate(state, _):
- return migrated
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--output", str(out)])
+ assert result.exit_code == 0
+ assert saved["path"] == out
+ assert saved["obj"]["optimizer"] == migrated
- monkeypatch.setattr(migrate_script.torch, "load", fake_load)
- monkeypatch.setattr(migrate_script, "migrate_state_dict", fake_migrate)
+def test_cli_overwrites_input_by_default(monkeypatch, runner, tmp_path):
+ ckpt = tmp_path / "ckpt.pt"
+ ckpt.touch()
saved = {}
- def fake_save(obj, path):
- saved["obj"] = obj
- saved["path"] = pathlib.Path(path)
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: {"optimizer": {"state": {}, "param_groups": []}})
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {}, "param_groups": []})
+ monkeypatch.setattr(migrate_script.torch, "save", lambda obj, path: saved.update(path=pathlib.Path(path)))
- monkeypatch.setattr(migrate_script.torch, "save", fake_save)
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock"])
+ assert result.exit_code == 0
+ assert saved["path"] == ckpt
- result = runner.invoke(
- migrate_script.app,
- [str(checkpoint_path), "heavyball.Mock", "--output", str(output_path)],
- )
+@pytest.mark.parametrize("key", ["optimizer", "model.optimizer", "a.b.c"])
+def test_cli_state_key_parsing(monkeypatch, runner, tmp_path, key):
+ ckpt = tmp_path / "ckpt.pt"
+ ckpt.touch()
+ parts = key.split(".")
+ inner = {"state": {}, "param_groups": []}
+ root = inner
+ for p in reversed(parts):
+ root = {p: root}
+
+ monkeypatch.setattr(migrate_script.torch, "load", lambda *a, **kw: root)
+ monkeypatch.setattr(migrate_script, "migrate_state_dict", lambda s, _: {"state": {}, "param_groups": []})
+ monkeypatch.setattr(migrate_script.torch, "save", lambda *a, **kw: None)
+
+ result = runner.invoke(migrate_script.app, [str(ckpt), "heavyball.Mock", "--state-key", key])
assert result.exit_code == 0
- assert saved["path"] == output_path
- assert saved["obj"]["optimizer"] == migrated
-def test_cli_migrates_legacy_checkpoint(runner, tmp_path):
- package_root = pathlib.Path(__file__).resolve().parents[1]
- heavyball_pkg = package_root / "heavyball"
- saved_heavyball_modules = {
- name: sys.modules[name] for name in list(sys.modules) if name == "heavyball" or name.startswith("heavyball.")
- }
- for name in list(sys.modules):
- if name == "heavyball" or name.startswith("heavyball."):
- sys.modules.pop(name)
+# ====================================================================
+# CLI - real end-to-end (fresh heavyball import)
+# ====================================================================
- spec = importlib.util.spec_from_file_location(
- "heavyball",
- heavyball_pkg / "__init__.py",
- submodule_search_locations=[str(heavyball_pkg)],
- )
- heavyball_module = importlib.util.module_from_spec(spec)
- sys.modules["heavyball"] = heavyball_module
- spec.loader.exec_module(heavyball_module)
- try:
- checkpoint_path = tmp_path / "legacy.pt"
- output_path = tmp_path / "migrated.pt"
-
- legacy_state = {
- "state": {
- 0: {
- "update_by_adam_exp_avg": torch.ones((2, 2), dtype=torch.float32),
- "update_by_adam_exp_avg_sq": torch.full((2, 2), 2.0, dtype=torch.float32),
- "is_initialized": [0],
- },
- 1: {
- "update_by_adam_exp_avg": torch.ones((2,), dtype=torch.float32),
- "update_by_adam_exp_avg_sq": torch.full((2,), 2.0, dtype=torch.float32),
- "is_initialized": [0],
- },
- },
- "param_groups": [
- {
- "params": [0, 1],
- "lr": 0.0025,
- "betas": [0.9, 0.99],
- "eps": 1e-8,
- "weight_decay": 0.0,
- "warmup_steps": 0,
- "multi_tensor": True,
- "storage_dtype": "float32",
- "mars": False,
- "caution": False,
- "mars_gamma": 0.0025,
- "gradient_clipping": "use_default",
- "update_clipping": "use_default",
- "palm": "use_default",
- "beta2_scale": 0.8,
- "__class__": "heavyball.AdamW",
- }
- ],
- }
+def _make_v1_checkpoint(path, shapes, *, group_overrides=None):
+ g = _v1_group(list(range(len(shapes))))
+ g["multi_tensor"] = True
+ g.pop("foreach")
+ g.pop("stochastic_schedule")
+ if group_overrides:
+ g.update(group_overrides)
+ torch.save({"optimizer": {"state": _flat_state(*shapes), "param_groups": [g]}}, path)
+
+
+def _make_v2_checkpoint(path, shapes, *, group_overrides=None):
+ g = _v2_group(list(range(len(shapes))))
+ if group_overrides:
+ g.update(group_overrides)
+ torch.save({"optimizer": {"state": _nested_state(*shapes), "param_groups": [g], "heavyball": _v2_meta()}}, path)
- torch.save({"optimizer": legacy_state}, checkpoint_path)
- result = runner.invoke(
- migrate_script.app,
- [str(checkpoint_path), "heavyball.AdamW", "--output", str(output_path)],
- )
+def _make_v3_checkpoint(path, shapes, *, group_overrides=None):
+ g = _v3_group(list(range(len(shapes))))
+ if group_overrides:
+ g.update(group_overrides)
+ torch.save({"optimizer": {"state": _nested_state(*shapes), "param_groups": [g], "heavyball": _v3_meta()}}, path)
+
+def _run_cli_e2e(runner, ckpt, out, class_name):
+ saved = _load_heavyball_fresh()
+ try:
+ result = runner.invoke(migrate_script.app, [str(ckpt), class_name, "--output", str(out)])
assert result.exit_code == 0, result.stderr or result.stdout
- assert output_path.exists()
-
- migrated = torch.load(output_path)
- migrated_state = migrated["optimizer"]
-
- for pid, shape in [(0, (2, 2)), (1, (2,))]:
- assert pid in migrated_state["state"], f"missing state for parameter {pid}"
- migrated_bucket = migrated_state["state"][pid]
- assert 0 in migrated_bucket, f"missing transformed view for parameter {pid}"
- view_state = migrated_bucket[0]
- assert view_state["update_by_adam_exp_avg_0"].shape == shape
- assert view_state["update_by_adam_exp_avg_sq_0"].shape == shape
- assert torch.allclose(view_state["update_by_adam_exp_avg_0"], torch.ones(shape))
- assert torch.allclose(view_state["update_by_adam_exp_avg_sq_0"], torch.full(shape, 2.0))
- assert view_state["is_initialized"] == [0]
-
- assert "heavyball" in migrated_state
- assert "Migrated checkpoint written to" in result.stdout
+ return torch.load(out)["optimizer"]
finally:
- for name in list(sys.modules):
- if name == "heavyball" or name.startswith("heavyball."):
- sys.modules.pop(name)
- sys.modules.update(saved_heavyball_modules)
+ _restore_heavyball(saved)
+
+
+# --- v1 end-to-end ---
+
+
+def test_e2e_v1_adamw(runner, tmp_path):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [(2, 2), (2,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ for pid, shape in [(0, (2, 2)), (1, (2,))]:
+ view = migrated["state"][pid][0]
+ assert view["update_by_adam_exp_avg_0"].shape == shape
+ assert view["update_by_adam_exp_avg_sq_0"].shape == shape
+ assert torch.allclose(view["update_by_adam_exp_avg_0"], torch.ones(shape))
+ assert view["is_initialized"] == [0]
+ assert "heavyball" in migrated
+
+
+@pytest.mark.parametrize("old_name,new_name", _DIRECT_RENAMES_NO_LRA)
+def test_e2e_v1_all_direct_renames(runner, tmp_path, old_name, new_name):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}")
+ assert 0 in migrated["state"]
+ assert "heavyball" in migrated
+
+
+@pytest.mark.parametrize("old_name,new_name", _DELETED_RENAMES_NO_LRA)
+def test_e2e_v1_all_deleted_renames(runner, tmp_path, old_name, new_name):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}")
+ assert 0 in migrated["state"]
+ assert "heavyball" in migrated
+
+
+# PSGDLRA v1 e2e: PSGDLRA computes rank from param numel during __init__,
+# which the migration script's _instantiate_optimizer can't replicate from
+# a state dict alone (rank isn't stored in param_groups). The rename
+# resolution and migration logic for LRA is covered by the unit tests above.
+
+
+@pytest.mark.parametrize("n_params", [1, 3, 5])
+def test_e2e_v1_varying_param_count(runner, tmp_path, n_params):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ shapes = [(i + 2,) for i in range(n_params)]
+ _make_v1_checkpoint(ckpt, shapes)
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ for pid in range(n_params):
+ assert pid in migrated["state"]
+
+
+@pytest.mark.parametrize("shape", [(4,), (2, 3), (2, 3, 4)])
+def test_e2e_v1_varying_shapes(runner, tmp_path, shape):
+ ckpt, out = tmp_path / "v1.pt", tmp_path / "out.pt"
+ _make_v1_checkpoint(ckpt, [shape])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ assert migrated["state"][0][0]["update_by_adam_exp_avg_0"].shape == shape
+
+
+# --- v2 end-to-end ---
+
+
+def test_e2e_v2_basic(runner, tmp_path):
+ ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt"
+ _make_v2_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ group = migrated["param_groups"][0]
+ assert group["multi_tensor"] is True
+ assert "foreach" not in group
+ assert "stochastic_schedule" not in group
+ hb = migrated["heavyball"]
+ assert "precond_rng" not in hb
+ assert "stochastic_schedule" not in hb
+
+
+@pytest.mark.parametrize("old_name,new_name", _DIRECT_RENAMES_NO_LRA)
+def test_e2e_v2_all_direct_renames(runner, tmp_path, old_name, new_name):
+ ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt"
+ _make_v2_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, f"heavyball.{old_name}")
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+ assert "foreach" not in migrated["param_groups"][0]
+
+
+@pytest.mark.parametrize("n_params", [1, 3, 5])
+def test_e2e_v2_varying_param_count(runner, tmp_path, n_params):
+ ckpt, out = tmp_path / "v2.pt", tmp_path / "out.pt"
+ shapes = [(i + 2,) for i in range(n_params)]
+ _make_v2_checkpoint(ckpt, shapes)
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ for pid in range(n_params):
+ assert pid in migrated["state"]
+
+
+# --- v3 end-to-end (noop) ---
+
+
+def test_e2e_v3_noop(runner, tmp_path):
+ ckpt, out = tmp_path / "v3.pt", tmp_path / "out.pt"
+ _make_v3_checkpoint(ckpt, [(4,)])
+ migrated = _run_cli_e2e(runner, ckpt, out, "heavyball.AdamW")
+ assert migrated["param_groups"][0]["multi_tensor"] is True
+ assert torch.equal(migrated["state"][0][0]["key_0"], torch.zeros(4))
diff --git a/test/test_param_ecc_compile.py b/test/test_param_ecc_compile.py
index e6fc4ce..6b8c462 100644
--- a/test/test_param_ecc_compile.py
+++ b/test/test_param_ecc_compile.py
@@ -108,7 +108,7 @@ def test_ecc_populated_rne():
p, ecc, _ = _train_linear("max-autotune-no-cudagraphs", rne=True)
assert ecc is not None, "param::ecc not found in optimizer state"
assert ecc.any(), (
- "param::ecc all zeros with RNE rounding under compile — Inductor likely folded the f32->bf16->f32 roundtrip"
+ "param::ecc all zeros with RNE rounding under compile, Inductor likely folded the f32->bf16->f32 roundtrip"
)
clean()
diff --git a/test/test_toy_training.py b/test/test_toy_training.py
index 0e2e83a..6a729af 100644
--- a/test/test_toy_training.py
+++ b/test/test_toy_training.py
@@ -29,10 +29,8 @@ def _flatten_tensors(tensors: Iterable[torch.Tensor]):
def _optimizer_params():
params = []
- for name in sorted(dir(heavyball)):
- if name.startswith("_"):
- continue
- attr = getattr(heavyball, name)
+ for name in sorted(heavyball.__all__):
+ attr = getattr(heavyball, name, None)
if not isinstance(attr, type) or not issubclass(attr, optim.Optimizer):
continue
if attr is optim.Optimizer:
diff --git a/test/utils.py b/test/utils.py
index 163445f..b00ff8c 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -34,7 +34,7 @@ def _deduplicate_by_chain(names):
"""Keep one optimizer per unique chain of functions.
Two optimizers that differ only by multi_tensor=True/False have identical
- chains and test the same code paths — keep whichever appears first.
+ chains and test the same code paths, keep whichever appears first.
"""
seen = set()
out = []
From 6f43bf5ad314fff8c7616b433f9882ffb30d5848 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Mon, 30 Mar 2026 08:47:40 +0000
Subject: [PATCH 08/24] improve interface
---
heavyball/chainable.py | 153 ++++++++++++++++++++------------------
heavyball/utils.py | 37 +++++----
test/test_cpu_features.py | 80 ++++++++++++++++++++
3 files changed, 183 insertions(+), 87 deletions(-)
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index b37e43d..abd0290 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -277,45 +277,50 @@ def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
class PrecondGradAccumGuard(FunctionTransform):
def __init__(self, fn):
super().__init__(fn, ["precond_grad_accum"])
- self.steps_taken = 0
- self.pass_through = None
+ self.steps_taken_key = None
- def _accum(self, state, new):
- self.steps_taken += 1
+ def _build_val_names(self):
+ super()._build_val_names()
+ self.steps_taken_key = f"_{self.fn_name}_steps_taken_{self.transform_idx}"
+
+ def _accum(self, group, state, new):
+ group[self.steps_taken_key] = group.get(self.steps_taken_key, 0) + 1
utils.stochastic_add_(state, new)
- def _reset(self, state):
- if self.steps_taken != 0:
- self.steps_taken = 0
+ def _reset(self, group, state):
+ if group.get(self.steps_taken_key, 0) != 0:
+ group[self.steps_taken_key] = 0
utils.zero_(state)
def _init(self, state: dict, group: dict, update: Tensor, grad: Tensor, param: Tensor, *args, **kwargs):
- if self.pass_through is None:
- self.pass_through = not group.get("precond_grad_accum", False)
- if self.pass_through is False:
- for name in self.names:
- _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
+ if not group.get("precond_grad_accum", False):
+ return
+ for name in self.names:
+ _zero_guard(state, self.val_name(name), param, _storage_dtype(group))
def _call(self, state, group, update, grad, param, vars, *args, **kwargs):
base_grad = update if group.get("momentum_into_precond_update", True) else grad
- if self.pass_through:
+ if not group.get("precond_grad_accum", False):
return self.fn(state, group, update, grad, param, *args, base_grad, **kwargs)
(vars,) = vars
+ steps_taken = group.get(self.steps_taken_key, 0)
+ accum_state = None
if group["is_preconditioning"]:
- if self.steps_taken:
- self._accum(vars, base_grad)
- utils.stochastic_multiply_(vars, 1 / self.steps_taken)
+ if steps_taken:
+ self._accum(group, vars, base_grad)
+ utils.stochastic_multiply_(vars, 1 / group[self.steps_taken_key])
+ accum_state = vars
else:
vars = base_grad
else:
- self._accum(vars, base_grad)
+ self._accum(group, vars, base_grad)
vars = base_grad
try:
out = self.fn(state, group, update, grad, param, *args, vars, **kwargs)
finally:
- if group["is_preconditioning"]:
- self._reset(vars)
+ if accum_state is not None:
+ self._reset(group, accum_state)
return out
@@ -1481,6 +1486,12 @@ def palm_beta2(state, group, update, grad, param):
def apply_to_idx(fn, idx):
+ name = fn
+ if isinstance(fn, str):
+ fn = getattr(utils, fn, None)
+ if fn is None or not callable(fn):
+ raise ValueError(f"Unknown function '{name}'")
+
def _fn(state, group, update, grad, param):
args = [state, group, update, grad, param]
return fn(args[idx])
@@ -1859,15 +1870,26 @@ def _step_inner(self, group):
return
p, g = zip(*vals)
- for param in p:
- state = self.state_(param)
- step = state.get("step", 0)
- if not isinstance(step, torch.Tensor):
- step = torch.tensor(step, dtype=torch.int64, device=param.device)
- state["step"] = step
- break
-
- group["step"] = state["step"] = step = step + 1
+ step = group.get("_group_step")
+ if step is None:
+ for param in group["params"]:
+ param_state = self.state.get(param)
+ if not isinstance(param_state, dict):
+ continue
+ for idx_state in param_state.values():
+ if isinstance(idx_state, dict) and "step" in idx_state:
+ step = idx_state["step"]
+ break
+ if step is not None:
+ break
+ else:
+ step = 0
+ if isinstance(step, torch.Tensor):
+ step = step.to(device=p[0].device, dtype=torch.int64)
+ else:
+ step = torch.tensor(step, dtype=torch.int64, device=p[0].device)
+ group["_group_step"] = group["step"] = step = step + 1
+ self.state_(p[0])["step"] = step
group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
if not group["multi_tensor"] or len(p) == 1:
@@ -2024,55 +2046,44 @@ def __init__(
class ScheduleFree(BaseOpt):
def eval(self):
+ return self.train(False)
+
+ def train(self, mode: bool = True):
z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode", True)
- beta1 = utils.get_beta1(group)
- if beta1 > 0 and not train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if z_key in state:
- z = utils.promote(state[z_key])
- p32 = utils.promote(p.data)
- p32.lerp_(end=z, weight=1 - 1 / beta1)
- utils.copy_stochastic_(p.data, p32)
-
- def train(self):
- z_key = self._find_val_name("z")
- for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode", False)
+ train_mode = group.get("train_mode", True)
+ if train_mode == mode:
+ continue
+ group["train_mode"] = mode
beta1 = utils.get_beta1(group)
- if beta1 > 0 and train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if z_key in state:
- z = utils.promote(state[z_key])
- p32 = utils.promote(p.data)
- p32.lerp_(end=z, weight=1 - beta1)
- utils.copy_stochastic_(p.data, p32)
+ if beta1 <= 0:
+ continue
+ weight = 1 - beta1 if mode else 1 - 1 / beta1
+ for p in group["params"]:
+ state = self.state_(p)
+ if z_key in state:
+ z = utils.promote(state[z_key])
+ p32 = utils.promote(p.data)
+ p32.lerp_(end=z, weight=weight)
+ utils.copy_stochastic_(p.data, p32)
+ return self
class MSAM(BaseOpt):
def eval(self):
+ return self.train(False)
+
+ def train(self, mode: bool = True):
z_key = self._find_val_name("z")
for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode", True)
- if not train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if z_key in state:
- p_copy = p.data.clone()
- utils.copy_stochastic_(p.data, state[z_key])
- utils.copy_stochastic_(state[z_key], p_copy)
-
- def train(self):
- z_key = self._find_val_name("z")
- for group in self.param_groups:
- group["train_mode"] = train_mode = not group.get("train_mode", False)
- if train_mode:
- for p in group["params"]:
- state = self.state_(p)
- if z_key in state:
- p_copy = p.data.clone()
- utils.copy_stochastic_(p.data, state[z_key])
- utils.copy_stochastic_(state[z_key], p_copy)
+ train_mode = group.get("train_mode", True)
+ if train_mode == mode:
+ continue
+ group["train_mode"] = mode
+ for p in group["params"]:
+ state = self.state_(p)
+ if z_key in state:
+ p_copy = p.data.clone()
+ utils.copy_stochastic_(p.data, state[z_key])
+ utils.copy_stochastic_(state[z_key], p_copy)
+ return self
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 44b4b28..e8021a0 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1590,6 +1590,19 @@ def _handle_closure(self, closure):
self._fallback_enabled = True
return self._handle_closure(closure)
+ def _cleanup_temporary_tensors(self):
+ for real, views in self.mapping.items():
+ fmt = getattr(real, "_restore_memory_format", None)
+ if fmt is not None:
+ del self.mapping_inverse[_tensor_key(real)]
+ real.data = real.data.to(memory_format=fmt)
+ self.mapping_inverse[_tensor_key(real)] = (real, 0)
+ del real._restore_memory_format
+ for tensor in (real, *views):
+ for key in ("grad", "vector", "hessian_vector", "orig"):
+ if hasattr(tensor, key):
+ setattr(tensor, key, None)
+
def step(self, closure: Optional[Callable] = None):
if self.precond_schedule is None:
self._is_preconditioning = False
@@ -1601,22 +1614,14 @@ def step(self, closure: Optional[Callable] = None):
self.ema_update()
# we assume that parameters are constant and that there are no excessive recompiles
with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
- for group in self.param_groups:
- if "param_count" not in group:
- group["param_count"] = sum(p.numel() for p in group["params"])
- group["is_preconditioning"] = self._is_preconditioning
- self._step(group)
- for real, views in self.mapping.items():
- fmt = getattr(real, "_restore_memory_format", None)
- if fmt is not None:
- del self.mapping_inverse[_tensor_key(real)]
- real.data = real.data.to(memory_format=fmt)
- self.mapping_inverse[_tensor_key(real)] = (real, 0)
- del real._restore_memory_format
- for tensor in (real, *views):
- for key in ("grad", "vector", "hessian_vector", "orig"):
- if hasattr(tensor, key):
- setattr(tensor, key, None)
+ try:
+ for group in self.param_groups:
+ if "param_count" not in group:
+ group["param_count"] = sum(p.numel() for p in group["params"])
+ group["is_preconditioning"] = self._is_preconditioning
+ self._step(group)
+ finally:
+ self._cleanup_temporary_tensors()
return loss
diff --git a/test/test_cpu_features.py b/test/test_cpu_features.py
index 82ed343..0a8b746 100644
--- a/test/test_cpu_features.py
+++ b/test/test_cpu_features.py
@@ -132,3 +132,83 @@ def closure():
after = [param.detach() for param in model.parameters()]
diff = torch.cat([(a - b).reshape(-1) for a, b in zip(after, before, strict=True)])
assert diff.norm().item() > 0.0
+
+
+def test_multiple_param_groups_keep_updating() -> None:
+ p1 = nn.Parameter(torch.zeros(()))
+ p2 = nn.Parameter(torch.zeros(()))
+ opt = heavyball.SGD(
+ [
+ {"params": [p1]},
+ {"params": [p2]},
+ ],
+ lr=0.1,
+ beta=0.0,
+ warmup_steps=0,
+ )
+
+ for _ in range(3):
+ p1.grad = torch.ones_like(p1)
+ p2.grad = torch.full_like(p2, 2.0)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ assert torch.allclose(p1.detach(), torch.tensor(-0.3))
+ assert torch.allclose(p2.detach(), torch.tensor(-0.6))
+
+
+def test_group_step_does_not_reset_when_active_param_changes() -> None:
+ p1 = nn.Parameter(torch.zeros(()))
+ p2 = nn.Parameter(torch.zeros(()))
+ opt = heavyball.SGD([p1, p2], lr=0.1, beta=0.0, warmup_steps=3)
+
+ p1.grad = torch.ones_like(p1)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ p2.grad = torch.ones_like(p2)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ assert p1.item() == pytest.approx(-0.025)
+ assert p2.item() == pytest.approx(-0.05)
+
+
+def test_string_clipping_shorthands_match_public_api() -> None:
+ model, data, target = _make_batch()
+ opt = heavyball.SGD(
+ model.parameters(),
+ lr=1e-3,
+ beta=0.0,
+ gradient_clipping="l2_clip_",
+ update_clipping="trust_region_clip_",
+ )
+
+ loss = _train_once(opt, model, data, target, steps=2)
+ assert torch.isfinite(torch.tensor(loss))
+
+
+@pytest.mark.parametrize("opt_cls", [heavyball.SFAdamW, heavyball.MSAMLaProp])
+def test_mode_switches_are_idempotent(opt_cls) -> None:
+ p = nn.Parameter(torch.tensor([1.0, -1.0]))
+ opt = opt_cls([p], lr=1e-2)
+
+ p.grad = torch.ones_like(p)
+ opt.step()
+ opt.zero_grad(set_to_none=True)
+
+ opt.eval()
+ eval_once = p.detach().clone()
+ opt.eval()
+ eval_twice = p.detach().clone()
+
+ assert torch.allclose(eval_once, eval_twice)
+ assert opt.param_groups[0]["train_mode"] is False
+
+ opt.train()
+ train_once = p.detach().clone()
+ opt.train()
+ train_twice = p.detach().clone()
+
+ assert torch.allclose(train_once, train_twice)
+ assert opt.param_groups[0]["train_mode"] is True
From 97ccdfb98ee31355ea2dd7bc177d0ddfdeb9a676 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Mon, 30 Mar 2026 11:21:15 +0000
Subject: [PATCH 09/24] migration docs
---
README.md | 7 +-
docs/heavyball3.md | 117 +++++++++++++++++++++++++++++
heavyball/__init__.py | 34 +++------
heavyball/chainable.py | 30 +++++++-
heavyball/utils.py | 5 +-
pyproject.toml | 2 +-
scripts/migrate_optimizer_state.py | 14 ++++
7 files changed, 176 insertions(+), 33 deletions(-)
create mode 100644 docs/heavyball3.md
diff --git a/README.md b/README.md
index e8bd664..d143e93 100644
--- a/README.md
+++ b/README.md
@@ -204,10 +204,11 @@ HeavyBall includes a diagnostic benchmark suite via [LightBench](https://github.
for silent optimizer failures across difficulty levels. Results and methodology are documented
in [docs/benchmark.md](docs/benchmark.md).
-## Migrating from 1.x
+## Migrating
-See the [2.0.0 migration notes](docs/heavyball2.md) for a full checklist, and `scripts/migrate_optimizer_state.py` for
-checkpoint conversion.
+**From 2.x** See the [3.0.0 migration guide](docs/heavyball3.md) for renamed classes, removed kwargs, and checkpoint conversion.
+
+**From 1.x** See the [2.0.0 migration notes](docs/heavyball2.md), then follow the 3.0.0 guide.
## Contributing
diff --git a/docs/heavyball3.md b/docs/heavyball3.md
new file mode 100644
index 0000000..f4fc2c5
--- /dev/null
+++ b/docs/heavyball3.md
@@ -0,0 +1,117 @@
+# HeavyBall 3.0.0
+
+## Highlights
+
+* Simplified public API: `Foreach*` prefixes removed, short names are now the canonical classes
+* New optimizers: `HyperBallAdamW`, `MuonAdamW`, `PSGDPRO`
+* `Route`-based param dispatch replaces manual `SplitOpt` for mixed-architecture optimizers
+* `ScheduleFree` and `MSAM` mode switches are now idempotent (`eval()` twice is safe)
+* Higher-precision PSGD preconditioner updates
+* `torch.compile`-friendly step with automatic eager fallback for init/preconditioning
+
+---
+
+## Breaking changes
+
+### Class renames
+
+Every `Foreach*` class is renamed to its short form. The short names already existed as
+aliases in 2.x and continue to work; only the `Foreach*` spellings break.
+
+| 2.x name | 3.x name |
+|---|---|
+| `ForeachAdamW` | `AdamW` |
+| `ForeachNAdam` | `NAdam` |
+| `ForeachAdEMAMix` | `AdEMAMix` |
+| `ForeachAdamC` | `AdamC` |
+| `ForeachRMSprop` | `RMSprop` |
+| `ForeachSFAdamW` | `SFAdamW` |
+| `ForeachADOPT` | `ADOPT` |
+| `ForeachMuon` | `Muon` |
+| `ForeachLaProp` | `LaProp` |
+| `ForeachSignLaProp` | `SignLaProp` |
+| `ForeachSOAP` | `SOAP` |
+| `ForeachSOAPNAdam` | `SOAPNAdam` |
+| `ForeachSOAPAdEMAMix` | `SOAPAdEMAMix` |
+| `ForeachSOLP` | `SOLP` |
+| `ForeachPSGDKron` | `PSGDKron` |
+| `ForeachPSGDLRA` | `PSGDLRA` |
+
+### Removed optimizer classes
+
+These were thin subclasses that only set a class-level default. Use the parent class with the
+corresponding constructor argument instead.
+
+| 2.x class | 3.x equivalent |
+|---|---|
+| `PaLMForeachSFAdamW` / `PaLMSFAdamW` | `SFAdamW(..., palm=True)` |
+| `PaLMForeachSOAP` / `PaLMSOAP` / `PalmForEachSoap` | `SOAP(..., palm=True)` |
+| `PrecondScheduleForeachSOAP` / `PrecondScheduleSOAP` | `SOAP(..., use_precond_schedule=True)` |
+| `PrecondSchedulePaLMForeachSOAP` / `PrecondSchedulePaLMSOAP` | `SOAP(..., palm=True, use_precond_schedule=True)` |
+| `ForeachPurePSGD` / `PurePSGD` | `PSGDKron(..., exp_avg_input=False)` |
+| `ForeachCachedPSGDKron` / `CachedPSGDKron` | `PSGDKron(...)` (caching is now the default) |
+| `ForeachDelayedPSGD` / `DelayedPSGD` | `PSGDKron(..., delayed=True)` |
+| `ForeachCachedDelayedPSGDKron` / `CachedDelayedPSGDKron` | `PSGDKron(..., delayed=True)` |
+| `ForeachCachedNewtonPSGD` / `NewtonPSGDKron` | `PSGDKron(..., hessian_approx=True)` |
+| `NewtonHybrid2PSGDKron` | `PSGDKron(..., hessian_approx=True, hvp_interval=2)` |
+| `ForeachDelayedPSGDLRA` / `DelayedPSGDLRA` | `PSGDLRA(..., delayed=True)` |
+| `ForeachNewtonPSGDLRA` / `NewtonPSGDLRA` | `PSGDLRA(..., hessian_approx=True)` |
+| `NewtonHybrid2PSGDLRA` | `PSGDLRA(..., hessian_approx=True, hvp_interval=2)` |
+
+### Renamed parameters
+
+| 2.x parameter | 3.x parameter | Notes |
+|---|---|---|
+| `foreach` | `multi_tensor` | Passing `foreach` emits a `FutureWarning` and remaps automatically |
+
+### Removed parameters
+
+These raise `TypeError` if passed. They were either unused or replaced by better defaults.
+
+| Parameter | Previously on | Notes |
+|---|---|----------------------------------------------------------|
+| `stochastic_schedule` | SOAP, PSGDKron, PSGDLRA | Deterministic accumulation schedule is now the only mode |
+| `normalize_grads` | SOAP variants | Was unused in the transform pipeline |
+| `correct_bias` | SOAP variants | Was unused in the transform pipeline |
+| `inverse_free` | PSGDKron | Use `quad_torch` or `PSGDPRO` for inverse-free PSGD |
+| `adaptive` | PSGDKron | Removed |
+
+### Chainable API renames
+
+| 2.x name | 3.x name |
+|---|---|
+| `Branch` | `Parallel` |
+
+### Behavioral changes
+
+* **ScheduleFree / MSAM `eval()` / `train()`**: Now idempotent. Calling `eval()` twice no
+ longer flips back to train mode. Both methods accept a `mode` argument matching
+ `nn.Module.train(mode)` and return `self`.
+* **PSGD dampening**: `dampen_grad` default changed from `2**-13` to `1e-9`, and dampening
+ epsilon uses `torch.finfo(float32).eps` regardless of input dtype. This improves
+ preconditioner accuracy but may change convergence behavior.
+
+---
+
+## Checkpoint migration
+
+Use the migration CLI to convert 1.x or 2.x checkpoints:
+
+```bash
+python scripts/migrate_optimizer_state.py <checkpoint.pt> <out.pt> heavyball.<Optimizer>
+```
+
+Old class names (including all aliases listed above) are resolved automatically.
+The `foreach` → `multi_tensor` key rename in param groups is handled automatically.
+
+---
+
+## Upgrade checklist
+
+1. Replace `from heavyball import Foreach*` with the short name (e.g., `ForeachAdamW` → `AdamW`)
+2. Replace `foreach=` with `multi_tensor=` in constructor calls
+3. Replace removed subclass instantiations with parent + kwargs (see table above)
+4. Remove any `stochastic_schedule`, `normalize_grads`, `correct_bias`, `inverse_free`, or `adaptive` kwargs
+5. Replace `Branch(...)` with `Parallel(...)` in custom chainable code
+6. Migrate checkpoints: `python scripts/migrate_optimizer_state.py old.pt new.pt heavyball.<Optimizer>`
+7. If you relied on `eval(); eval()` toggling back to train mode, update your code
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index a7a69ee..d4357c4 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -30,7 +30,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.heavyball_momentum,))
class AdamW(C.BaseOpt):
@@ -58,7 +58,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adam,))
class NAdam(C.BaseOpt):
@@ -88,7 +88,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_nadam,))
class AdEMAMix(C.BaseOpt):
@@ -120,7 +120,7 @@ def __init__(
raise ValueError("AdEMAMix expects betas with three coefficients.")
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.update_by_ademamix,))
class UnscaledAdamW(C.BaseOpt):
@@ -149,7 +149,7 @@ def __init__(
):
params, defaults = C._build_defaults(locals())
super().__init__(
- params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
+ params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
)
@@ -179,7 +179,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_suds,))
class Scion(C.BaseOpt):
@@ -222,7 +222,7 @@ def __init__(
defaults.pop("momentum", None)
super().__init__(
- params, defaults, multi_tensor, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm)
+ params, defaults, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm)
)
@@ -258,7 +258,7 @@ def __init__(
max_lr = lr
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adamc,))
class RMSprop(C.BaseOpt):
@@ -295,7 +295,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -331,7 +330,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -372,7 +370,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -413,7 +410,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -452,7 +448,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -485,7 +480,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_adopt,))
class Muon(C.BaseOpt):
@@ -526,7 +521,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -559,7 +553,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(params, defaults, multi_tensor, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.update_by_laprop,))
class MuonLaProp(C.BaseOpt):
@@ -590,7 +584,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -613,7 +606,6 @@ def _build_soap_defaults(self, locals_dict, fns):
super().__init__(
params,
defaults,
- locals_dict["multi_tensor"],
locals_dict["gradient_clipping"],
locals_dict["update_clipping"],
locals_dict.get("palm", False),
@@ -773,7 +765,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -855,7 +846,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -891,7 +881,6 @@ def __init__(
super().__init__(
params,
defaults,
- multi_tensor,
gradient_clipping,
update_clipping,
palm,
@@ -921,7 +910,6 @@ def _build_psgd_defaults(self, locals_dict, fns, *, default_update_clipping=util
super().__init__(
params,
defaults,
- locals_dict["multi_tensor"],
locals_dict["gradient_clipping"],
update_clipping,
False,
@@ -1139,7 +1127,7 @@ def __init__(self, specs):
all_params.extend(params)
if not self.optimizers:
raise ValueError("No optimizers created")
- super().__init__(all_params, {}, multi_tensor=True)
+ super().__init__(all_params, {"multi_tensor": True})
def _step(self, group):
pass
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index abd0290..1c210c3 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -2,6 +2,7 @@
import copy
import functools
import math
+import warnings
from collections.abc import Iterable as _Iterable
from typing import Iterable, List, Literal, Optional, Union
@@ -180,12 +181,34 @@ def _storage_dtype(group):
_PASSTHROUGH_KWARGS = {"orig_shapes"}
+_RENAMED_KWARGS = {"foreach": "multi_tensor"}
+
+_REMOVED_KWARGS = frozenset({
+ "stochastic_schedule",
+ "normalize_grads",
+ "correct_bias",
+ "inverse_free",
+ "adaptive",
+})
+
def _build_defaults(locals_dict):
d = locals_dict.copy()
d.pop("self")
params = d.pop("params")
kwargs = d.pop("kwargs")
+
+ for old, new in _RENAMED_KWARGS.items():
+ if old in kwargs:
+ warnings.warn(f"'{old}' was renamed to '{new}' in HeavyBall 3.0. "
+ f"Pass '{new}' instead.", FutureWarning, stacklevel=4)
+ d[new] = kwargs.pop(old)
+
+ hit = _REMOVED_KWARGS & kwargs.keys()
+ if hit:
+ raise TypeError(f"Removed in HeavyBall 3.0: {', '.join(sorted(hit))}. "
+ f"See docs/heavyball3.md for migration details.")
+
d.update(kwargs)
unknown = {k: v for k, v in kwargs.items() if k not in _PASSTHROUGH_KWARGS}
if unknown:
@@ -1725,14 +1748,14 @@ class ChainOpt(utils.StatefulOptimizer):
"eps": 1e-8,
}
- def __init__(self, params, defaults, multi_tensor: bool, *fns):
+ def __init__(self, params, defaults, *fns):
orig = defaults.pop("orig_shapes", None)
self._orig_shapes = (
{k: _ShapeInfo(v) if isinstance(v, tuple) else v for k, v in orig.items()} if orig is not None else None
)
base = self.global_defaults.copy()
base.update({k: v for k, v in defaults.items() if v is not use_default})
- super().__init__(params, base, multi_tensor)
+ super().__init__(params, base)
self.fns = fns
self._eager_chain = self._run_chain
if self.compile_step:
@@ -1993,7 +2016,6 @@ def __init__(
self,
params,
defaults,
- multi_tensor: bool = True,
gradient_clipping: str_or_fn = None,
update_clipping: str_or_fn = None,
palm: bool = use_default,
@@ -2041,7 +2063,7 @@ def __init__(
if default(update_clipping, self.update_clipping) is not None:
fns = fns + (apply_to_idx(update_clipping, 2),)
- super().__init__(params, defaults, multi_tensor, *fns)
+ super().__init__(params, defaults, *fns)
class ScheduleFree(BaseOpt):
diff --git a/heavyball/utils.py b/heavyball/utils.py
index e8021a0..e2207a5 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1306,13 +1306,14 @@ class StatefulOptimizer(torch.optim.Optimizer):
"hessian_approx",
)
- def __init__(self, params, defaults, multi_tensor: bool = True, use_ema: bool = False):
+ def __init__(self, params, defaults, use_ema: bool = False):
for attr in self._INSTANCE_ATTRS:
if attr in defaults:
val = defaults.pop(attr)
if val is not use_default:
setattr(self, attr, val)
- super().__init__(params, {**defaults, "multi_tensor": multi_tensor})
+ defaults.setdefault("multi_tensor", True)
+ super().__init__(params, defaults)
self.use_ema = use_ema
self.mapping = {}
self.mapping_inverse = {}
diff --git a/pyproject.toml b/pyproject.toml
index 54c4d55..7c2d78d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "heavyball"
description = "Compile-first PyTorch optimizer library - AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features"
-version = "2.3.2"
+version = "3.0.0"
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
classifiers = ["Intended Audience :: Developers",
"Intended Audience :: Science/Research",
diff --git a/scripts/migrate_optimizer_state.py b/scripts/migrate_optimizer_state.py
index 0c0e3e0..162538c 100644
--- a/scripts/migrate_optimizer_state.py
+++ b/scripts/migrate_optimizer_state.py
@@ -24,6 +24,7 @@
import typer
_CLASS_RENAMES = {
+ # Foreach* internal names → canonical 3.x names
"ForeachAdamW": "AdamW",
"ForeachNAdam": "NAdam",
"ForeachAdEMAMix": "AdEMAMix",
@@ -54,6 +55,19 @@
"ForeachDelayedPSGDLRA": "PSGDLRA",
"ForeachNewtonPSGDLRA": "PSGDLRA",
"NewtonHybrid2PSGDLRA": "PSGDLRA",
+ # 2.x public aliases that pointed to Foreach* classes (now removed)
+ "PaLMSOAP": "SOAP",
+ "PaLMSFAdamW": "SFAdamW",
+ "PalmForEachSoap": "SOAP",
+ "PrecondScheduleSOAP": "SOAP",
+ "PrecondSchedulePaLMSOAP": "SOAP",
+ "PurePSGD": "PSGDKron",
+ "DelayedPSGD": "PSGDKron",
+ "CachedPSGDKron": "PSGDKron",
+ "CachedDelayedPSGDKron": "PSGDKron",
+ "NewtonPSGDKron": "PSGDKron",
+ "DelayedPSGDLRA": "PSGDLRA",
+ "NewtonPSGDLRA": "PSGDLRA",
}
_REMOVED_GROUP_KEYS = frozenset({"stochastic_schedule"})
From cc981339e5fdcf08a8f8e0a793f4d6f510262d24 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Mon, 30 Mar 2026 11:21:39 +0000
Subject: [PATCH 10/24] ruff
---
heavyball/__init__.py | 39 ++++++++++++++++++------------
heavyball/chainable.py | 51 ++++++++++++++++++++-------------------
heavyball/utils.py | 10 ++------
test/test_compile_step.py | 4 +--
test/test_migrate_cli.py | 18 +++++++++++---
test/test_utils_cpu.py | 1 -
6 files changed, 68 insertions(+), 55 deletions(-)
diff --git a/heavyball/__init__.py b/heavyball/__init__.py
index d4357c4..46a0f7c 100644
--- a/heavyball/__init__.py
+++ b/heavyball/__init__.py
@@ -148,9 +148,7 @@ def __init__(
**kwargs,
):
params, defaults = C._build_defaults(locals())
- super().__init__(
- params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,)
- )
+ super().__init__(params, defaults, gradient_clipping, update_clipping, palm, fns=(C.scale_by_unscaled_adam,))
class SUDSAdamW(C.BaseOpt):
@@ -221,9 +219,7 @@ def __init__(
defaults["scale"] = scale
defaults.pop("momentum", None)
- super().__init__(
- params, defaults, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm)
- )
+ super().__init__(params, defaults, gradient_clipping, update_clipping, fns=(C.exp_avg, C.scion_auto_norm))
class AdamC(C.BaseOpt):
@@ -333,10 +329,13 @@ def __init__(
gradient_clipping,
update_clipping,
palm,
- fns=(C.scale_by_exp_avg_sq, C.route(
- (lambda p: p.ndim >= 2, C.update_by_hyperball),
- default=C.apply_update,
- )),
+ fns=(
+ C.scale_by_exp_avg_sq,
+ C.route(
+ (lambda p: p.ndim >= 2, C.update_by_hyperball),
+ default=C.apply_update,
+ ),
+ ),
)
@@ -373,10 +372,12 @@ def __init__(
gradient_clipping,
update_clipping,
palm,
- fns=(C.route(
- (lambda p: p.ndim >= 2, (ema, C.orthogonalize_update)),
- default=C.scale_by_adam,
- ),),
+ fns=(
+ C.route(
+ (lambda p: p.ndim >= 2, (ema, C.orthogonalize_update)),
+ default=C.scale_by_adam,
+ ),
+ ),
)
@@ -893,7 +894,9 @@ class PSGDBase(C.BaseOpt):
cached: bool = False
exp_avg_input: bool = True
- def _build_psgd_defaults(self, locals_dict, fns, *, default_update_clipping=utils.trust_region_clip_, extra_defaults=None):
+ def _build_psgd_defaults(
+ self, locals_dict, fns, *, default_update_clipping=utils.trust_region_clip_, extra_defaults=None
+ ):
exp_avg_input = C.default(locals_dict.get("exp_avg_input", C.use_default), self.exp_avg_input)
update_clipping = C.default(locals_dict["update_clipping"], default_update_clipping)
@@ -1204,4 +1207,8 @@ def zero_grad(self, set_to_none: bool = True):
capture_param_shapes = utils.capture_param_shapes
_BASE_CLASSES = {SOAPBase, PSGDBase}
-__all__ = [k for k, v in globals().items() if isinstance(v, type) and issubclass(v, torch.optim.Optimizer) and v not in _BASE_CLASSES]
+__all__ = [
+ k
+ for k, v in globals().items()
+ if isinstance(v, type) and issubclass(v, torch.optim.Optimizer) and v not in _BASE_CLASSES
+]
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index 1c210c3..c676931 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -25,7 +25,6 @@ def _key_in_state(state, key):
return True
-
def _guard_in_state(state, key, template_fn):
if not _key_in_state(state, key):
state[key] = template_fn()
@@ -150,7 +149,9 @@ def _sel(lst, idx):
continue
group["caution"] = caution
if fns is not None:
- u, skip = _inner_chain(_sel(state, idx), group, _sel(update, idx), _sel(grad, idx), _sel(param, idx), *fns)
+ u, skip = _inner_chain(
+ _sel(state, idx), group, _sel(update, idx), _sel(grad, idx), _sel(param, idx), *fns
+ )
else:
u, skip = _sel(update, idx), False
results.append((u, skip, idx))
@@ -183,13 +184,15 @@ def _storage_dtype(group):
_RENAMED_KWARGS = {"foreach": "multi_tensor"}
-_REMOVED_KWARGS = frozenset({
- "stochastic_schedule",
- "normalize_grads",
- "correct_bias",
- "inverse_free",
- "adaptive",
-})
+_REMOVED_KWARGS = frozenset(
+ {
+ "stochastic_schedule",
+ "normalize_grads",
+ "correct_bias",
+ "inverse_free",
+ "adaptive",
+ }
+)
def _build_defaults(locals_dict):
@@ -200,14 +203,16 @@ def _build_defaults(locals_dict):
for old, new in _RENAMED_KWARGS.items():
if old in kwargs:
- warnings.warn(f"'{old}' was renamed to '{new}' in HeavyBall 3.0. "
- f"Pass '{new}' instead.", FutureWarning, stacklevel=4)
+ warnings.warn(
+ f"'{old}' was renamed to '{new}' in HeavyBall 3.0. Pass '{new}' instead.", FutureWarning, stacklevel=4
+ )
d[new] = kwargs.pop(old)
hit = _REMOVED_KWARGS & kwargs.keys()
if hit:
- raise TypeError(f"Removed in HeavyBall 3.0: {', '.join(sorted(hit))}. "
- f"See docs/heavyball3.md for migration details.")
+ raise TypeError(
+ f"Removed in HeavyBall 3.0: {', '.join(sorted(hit))}. See docs/heavyball3.md for migration details."
+ )
d.update(kwargs)
unknown = {k: v for k, v in kwargs.items() if k not in _PASSTHROUGH_KWARGS}
@@ -795,7 +800,9 @@ def _adopt_warmup_2(state, group, update, grad, param, exp_avg, exp_avg_sq):
u = utils.promote(update)
easq = utils.promote(exp_avg_sq)
utils.copy_stochastic_(exp_avg, u / easq.sqrt().clamp_(min=group["eps"]))
- utils.scale_by_exp_avg_sq_([exp_avg_sq], [update], utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"])
+ utils.scale_by_exp_avg_sq_(
+ [exp_avg_sq], [update], utils.beta_debias(utils.get_beta2(group), group["step"]), group["eps"]
+ )
@zero_guard("exp_avg", "exp_avg_sq")
@@ -930,7 +937,6 @@ def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob
)
-
@needs_full_param
@no_state_no_multi_tensor
def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
@@ -1266,9 +1272,7 @@ def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
if group.get("is_cached", False) and Q_cache[0] is not None:
out = utils.precond_grad_cached_(cached_q=Q_cache, **kwargs)
else:
- out = utils.psgd_precond_grad(
- preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs
- )
+ out = utils.psgd_precond_grad(preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs)
group["caution"] = False # we already cautioned here - shouldn't do it again
return out
@@ -1285,9 +1289,7 @@ def _fused_cached_psgd_precond_grad(group, grad, param, update, Q, Q_cache):
if group.get("is_cached", False) and Q_cache[0] is not None:
utils.fused_precond_grad_cached_(cached_q=Q_cache, **kwargs)
else:
- utils.fused_psgd_precond_grad(
- preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs
- )
+ utils.fused_psgd_precond_grad(preconds=Q, store_triu_as_line=group["store_triu_as_line"], **kwargs)
def _update_lra(
@@ -1393,9 +1395,7 @@ def scale_by_delayed_psgd(
prob: Optional[callable] = None,
):
precond = _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
- _update_psgd_precond(
- cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob
- )
+ _update_psgd_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
return precond
@@ -1851,7 +1851,8 @@ def fns(self, value):
self._set_indices(retain=True)
self._needs_gather = any(getattr(ft, "needs_full_param", False) for ft in _walk_fns(self._fns))
self._transform_ids = frozenset(
- ft.transform_idx for ft in _walk_fns(self._fns)
+ ft.transform_idx
+ for ft in _walk_fns(self._fns)
if ft.transform_idx is not None and getattr(ft, "needs_init", True)
)
diff --git a/heavyball/utils.py b/heavyball/utils.py
index e2207a5..76e6da3 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -7,11 +7,10 @@
import itertools
import math
import pickle
-import random
import re
import string
import warnings
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -3172,10 +3171,7 @@ def _psgd_precond_update_(
@decorator_knowngood
-
-
@decorator
-
@decorator_knowngood
def _clip(x, norm, clip_at, eps=1e-8):
x32 = promote(x)
@@ -3564,9 +3560,7 @@ def fused_psgd_precond_grad(
store_triu_as_line: bool = False,
):
lr, decay = scalar_guard(lr, decay, param[0])
- _compilable_fused_psgd_precond_grad(
- ea, param, lr, grad, decay, caution, preconds, store_triu_as_line
- )
+ _compilable_fused_psgd_precond_grad(ea, param, lr, grad, decay, caution, preconds, store_triu_as_line)
@decorator_knowngood
diff --git a/test/test_compile_step.py b/test/test_compile_step.py
index eb7bbaf..ce0a8e5 100644
--- a/test/test_compile_step.py
+++ b/test/test_compile_step.py
@@ -48,6 +48,7 @@ def _make_model():
def _run_steps(model, optimizer, n=5, seed=0xDEADBEEF):
torch.manual_seed(seed)
for _ in range(n):
+
def closure():
optimizer.zero_grad(set_to_none=True)
data = torch.randn(4, 8, device=DEVICE)
@@ -109,6 +110,5 @@ def test_needs_init_clears(opt_name, opt_cls):
for group in opt.param_groups:
state = [opt.state_(p) for p in group["params"]]
assert not opt._needs_init(state), (
- f"{opt_name}: _needs_init stuck True after {n} steps | "
- f"compile_step will never engage"
+ f"{opt_name}: _needs_init stuck True after {n} steps | compile_step will never engage"
)
diff --git a/test/test_migrate_cli.py b/test/test_migrate_cli.py
index d3f3de7..0130305 100644
--- a/test/test_migrate_cli.py
+++ b/test/test_migrate_cli.py
@@ -216,9 +216,17 @@ def test_load_optimizer_class_invalid_raises():
("flat_neither", {"state": _flat_state((2,)), "param_groups": [{}]}, 1),
("flat_multi_group", {"state": _flat_state((2,), (3,)), "param_groups": [{"foreach": True}, {}]}, 1),
("nested_foreach", {"state": _nested_state((2,)), "param_groups": [{"foreach": True}]}, 2),
- ("nested_foreach_multi_group", {"state": _nested_state((2,), (3,)), "param_groups": [{"foreach": True}, {"foreach": True}]}, 2),
+ (
+ "nested_foreach_multi_group",
+ {"state": _nested_state((2,), (3,)), "param_groups": [{"foreach": True}, {"foreach": True}]},
+ 2,
+ ),
("nested_multi_tensor", {"state": _nested_state((2,)), "param_groups": [{"multi_tensor": True}]}, 3),
- ("nested_multi_tensor_multi_group", {"state": _nested_state((2,), (3,)), "param_groups": [{"multi_tensor": True}, {}]}, 3),
+ (
+ "nested_multi_tensor_multi_group",
+ {"state": _nested_state((2,), (3,)), "param_groups": [{"multi_tensor": True}, {}]},
+ 3,
+ ),
("empty_state_multi_tensor", {"state": {}, "param_groups": [{"multi_tensor": True}]}, 3),
("empty_state_foreach", {"state": {}, "param_groups": [{"foreach": True}]}, 1),
("empty_state_bare", {"state": {}, "param_groups": [{}]}, 1),
@@ -246,7 +254,11 @@ def test_detect_version(name, sd, expected):
("empty_group", {}, {}),
("params_only", {"params": [0]}, {}),
("all_removed", {"params": [0], "stochastic_schedule": None}, {}),
- ("kitchen_sink", {"params": [0, 1], "foreach": True, "stochastic_schedule": True, "lr": 0.01, "betas": [0.9, 0.99]}, {"multi_tensor": True, "lr": 0.01, "betas": (0.9, 0.99)}),
+ (
+ "kitchen_sink",
+ {"params": [0, 1], "foreach": True, "stochastic_schedule": True, "lr": 0.01, "betas": [0.9, 0.99]},
+ {"multi_tensor": True, "lr": 0.01, "betas": (0.9, 0.99)},
+ ),
]
diff --git a/test/test_utils_cpu.py b/test/test_utils_cpu.py
index 621056c..b9af3d7 100644
--- a/test/test_utils_cpu.py
+++ b/test/test_utils_cpu.py
@@ -1,5 +1,4 @@
import os
-import random
import warnings
from copy import deepcopy
From 164ba3ad3fba67fa64e95ee72c3316573456fd39 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Mon, 30 Mar 2026 18:57:09 +0000
Subject: [PATCH 11/24] reduce python overhead
---
README.md | 7 +-
heavyball/chainable.py | 378 +++++++++++++++++++++++++++++----------
heavyball/utils.py | 122 ++++++++-----
test/test_distributed.py | 39 +++-
4 files changed, 394 insertions(+), 152 deletions(-)
diff --git a/README.md b/README.md
index d143e93..3ae2ebd 100644
--- a/README.md
+++ b/README.md
@@ -128,9 +128,10 @@ Available modes: `bf16+8`, `bf16+16`, `fp16+8`, `fp16+16`.
HeavyBall works with both DDP and FSDP. First-order optimizers are elementwise and operate directly on FSDP shards with
no repartitioning. Second-order methods (Muon, SOAP, PSGD) need the full parameter to compute their update, so HeavyBall
-auto-detects FSDP-sharded parameters on the first step and repartitions them: each weight matrix is assigned to one rank
-in round-robin, which reconstructs the full parameter, computes the update, and broadcasts the result. This saves both
-compute and memory compared to DDP-style redundant updates, at the cost of communication.
+auto-detects FSDP-sharded parameters on the first step and repartitions them with a metadata-first `all_to_all_single`
+exchange: each weight matrix is deterministically assigned to one rank, shard metadata is exchanged up front, the owner
+reconstructs the full parameter, computes the update once, and returns the updated shards. This saves both compute and
+memory compared to DDP-style redundant updates, at the cost of communication.
```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index c676931..5204dca 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -755,12 +755,12 @@ def orthogonalize_grad_to_param(group, update, grad, param):
def update_by_schedule_free(group, update, grad, param, z):
# Compute weight_sum once per step, not per param in no-multi_tensor mode.
if group.get("_sf_step") is not group["step"]:
- weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
+ weight = abs(group["lr"]) ** group["weight_lr_power"] * group["step"].clamp(min=1) ** group["r"]
group["weight_sum"] = group.get("weight_sum", 0) + weight
group["_sf_step"] = group["step"]
weight_sum = group["weight_sum"]
- weight = abs(group["lr"]) ** group["weight_lr_power"] * max(group["step"], 1) ** group["r"]
+ weight = abs(group["lr"]) ** group["weight_lr_power"] * group["step"].clamp(min=1) ** group["r"]
try:
ckp1 = weight / weight_sum
except ZeroDivisionError:
@@ -881,6 +881,7 @@ def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ tmp = utils.get_temporary(group, param) or {}
Q = utils.init_Q_exprs(
grad,
group["precond_init_scale"],
@@ -889,8 +890,8 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
group["max_size_triangular"],
group["min_ndim_triangular"],
group["memory_save_mode"],
- getattr(param, "hessian_vector", None),
- getattr(param, "vector", None),
+ tmp.get("hessian_vector"),
+ tmp.get("vector"),
dtype=getattr(torch, group["q_dtype"]),
)
state["Q"] = utils.triu_to_line(Q) if group["store_triu_as_line"] else Q
@@ -924,6 +925,7 @@ def _init_psgd_pro_kron(state, group, update, grad, param, cached: bool = False,
def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
+ tmp = utils.get_temporary(group, param) or {}
state["U"], state["V"], state["d"] = utils.init_lra(
grad,
group["param_count"],
@@ -931,8 +933,8 @@ def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob
group["precond_init_scale_scale"],
group["precond_init_scale_power"],
group["rank"],
- getattr(param, "hessian_vector", None),
- getattr(param, "vector", None),
+ tmp.get("hessian_vector"),
+ tmp.get("vector"),
dtype=getattr(torch, group["q_dtype"]),
)
@@ -1169,12 +1171,10 @@ def _update_psgd_precond(
if not group["is_preconditioning"]:
return
- if utils.hasattr_none(param, "vector"):
- vector, hessian_vector = param.vector, param.hessian_vector
- del param.vector
- del param.hessian_vector
- else:
+ if (utils.get_temporary(group, param) or {}).get("vector") is None:
vector, hessian_vector = utils.dampen_grad(grad, group["dampening"])
+ else:
+ vector, hessian_vector = utils.take_temporary(group, param, "vector", "hessian_vector")
precond = utils.psgd_update_precond(
hessian_vector,
@@ -1298,12 +1298,10 @@ def _update_lra(
if not group["is_preconditioning"]:
return utils.multi_flatten((U, 1), (V, 1), (d, 0))
- if utils.hasattr_none(params[0], "hessian_vector"):
- vector = utils.flatten([p.vector for p in params])
- hessian_vector = utils.flatten([p.hessian_vector for p in params])
- for p in params:
- del p.vector
- del p.hessian_vector
+ if (utils.get_temporary(group, params[0]) or {}).get("hessian_vector") is not None:
+ vector_hv = [utils.take_temporary(group, p, "vector", "hessian_vector") for p in params]
+ vector = utils.flatten([v for v, _ in vector_hv])
+ hessian_vector = utils.flatten([hv for _, hv in vector_hv])
else:
vector, hessian_vector = utils.dampen_multiple(grads)
precond_step = group["precond_step"] = group.get("precond_step", -1) + 1
@@ -1523,15 +1521,70 @@ def _fn(state, group, update, grad, param):
return _fn
+_FSDP_HEADER_WIDTH = 4
+_FSDP_BUCKET_BYTES = 32 << 20
+_FSDP_DTYPE_CODES = {
+ torch.float64: 0,
+ torch.float32: 1,
+ torch.float16: 2,
+ torch.bfloat16: 3,
+ torch.int64: 4,
+ torch.int32: 5,
+ torch.int16: 6,
+ torch.int8: 7,
+ torch.uint8: 8,
+ torch.bool: 9,
+}
+
+
class _ShapeInfo:
- __slots__ = ("orig_shape", "offset", "total", "group", "owner")
+ __slots__ = ("orig_shape", "offset", "total", "group", "owner", "param_idx")
- def __init__(self, orig_shape, offset=0, total=None, group=None, owner=None):
+ def __init__(self, orig_shape, offset=0, total=None, group=None, owner=None, param_idx=None):
self.orig_shape = orig_shape
self.offset = offset
self.total = total if total is not None else math.prod(orig_shape)
self.group = group
self.owner = owner
+ self.param_idx = param_idx
+
+
+class _FSDPBucket:
+ __slots__ = ("device", "dtype", "send_entries", "send_splits", "recv_entries", "recv_splits")
+
+ def __init__(self, device, dtype, send_entries, send_splits, recv_entries, recv_splits):
+ self.device = device
+ self.dtype = dtype
+ self.send_entries = send_entries
+ self.send_splits = send_splits
+ self.recv_entries = recv_entries
+ self.recv_splits = recv_splits
+
+
+class _FSDPState:
+ __slots__ = ("items", "buckets")
+
+ def __init__(self, items, buckets):
+ self.items = items
+ self.buckets = buckets
+
+
+def _dtype_code(dtype):
+ if dtype not in _FSDP_DTYPE_CODES:
+ raise TypeError(f"Unsupported FSDP shard dtype: {dtype}")
+ return _FSDP_DTYPE_CODES[dtype]
+
+
+def _assign_fsdp_owners(entries, shard_sizes, world_size):
+ loads = [0] * world_size
+ owners = []
+ for i, (p, _, total, _) in enumerate(entries):
+ active = shard_sizes[i].nonzero().squeeze(-1).tolist()
+ candidates = active or list(range(world_size))
+ owner = min(candidates, key=loads.__getitem__)
+ loads[owner] += total * p.element_size()
+ owners.append(owner)
+ return owners
def _detect_orig_shapes(params):
@@ -1554,8 +1607,6 @@ def _detect_orig_shapes(params):
for param, spi, shape in zip(obj._params, obj._shard_param_infos, obj._shapes):
lookup[id(param)] = (tuple(shape), spi)
- _dist = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
-
# optimizer param order is stable across ranks
fsdp_entries = [
(p, s, math.prod(s), spi)
@@ -1563,47 +1614,178 @@ def _detect_orig_shapes(params):
for s, spi in [lookup.get(id(p), (None, None))]
if id(p) in fsdp_ids and s is not None
]
-
- groups = {}
- if _dist and fsdp_entries:
- rank = torch.distributed.get_rank()
- ws = torch.distributed.get_world_size()
- n = len(fsdp_entries)
- flags = torch.zeros(n, ws, dtype=torch.int32, device=fsdp_entries[0][0].device)
- for i, (p, orig, total, spi) in enumerate(fsdp_entries):
- if spi.in_shard and spi.numel_in_shard is not None and spi.numel_in_shard < total:
- flags[i, rank] = 1
- torch.distributed.all_reduce(flags)
- pg_cache = {}
- for i in range(n):
- ranks = flags[i].nonzero().squeeze(-1).tolist()
- if len(ranks) >= 2:
- key = tuple(ranks)
- if key not in pg_cache:
- pg_cache[key] = torch.distributed.new_group(ranks)
- groups[i] = (pg_cache[key], ranks)
-
- # owner must be precomputed (different ranks see different subsets in _reshape_params)
result = {}
- split_idx = 0
- for i, (p, orig, total, spi) in enumerate(fsdp_entries):
- sg = groups.get(i)
- if sg:
- pg, ranks = sg
- owner = ranks[split_idx % len(ranks)]
- split_idx += 1
- if spi.in_shard:
- result[id(p)] = _ShapeInfo(orig, spi.intra_param_start_idx, total, pg, owner)
- elif spi.in_shard:
- result[id(p)] = _ShapeInfo(orig, spi.intra_param_start_idx, total)
+ ws = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+ if ws > 1 and fsdp_entries:
+ rank = torch.distributed.get_rank()
+ shard_sizes = torch.zeros(len(fsdp_entries), ws, dtype=torch.int64, device=fsdp_entries[0][0].device)
+ for i, (p, _, _, spi) in enumerate(fsdp_entries):
+ shard_sizes[i, rank] = p.numel() if spi.in_shard else 0
+ torch.distributed.all_reduce(shard_sizes)
+ owners = _assign_fsdp_owners(fsdp_entries, shard_sizes, ws)
+ else:
+ owners = [None] * len(fsdp_entries)
+ for param_idx, ((p, orig, total, spi), owner) in enumerate(zip(fsdp_entries, owners)):
+ offset = 0 if spi.intra_param_start_idx is None else spi.intra_param_start_idx
+ result[id(p)] = _ShapeInfo(orig, offset, total, owner=owner, param_idx=param_idx)
return result
-def _reduce_gather(shard, offset, total, pg, dst):
- full = shard.new_zeros(total)
- full[offset : offset + shard.numel()] = shard
- torch.distributed.reduce(full, dst=dst, group=pg)
- return full
+def _exchange_split_sizes(splits, device):
+ send = torch.tensor(splits, dtype=torch.int64, device=device)
+ recv = torch.empty_like(send)
+ torch.distributed.all_to_all_single(recv, send)
+ return recv.tolist()
+
+
+def _all_to_all_variable(sendbuf, recv_splits, send_splits):
+ recv = sendbuf.new_empty(sum(recv_splits))
+ torch.distributed.all_to_all_single(recv, sendbuf, output_split_sizes=recv_splits, input_split_sizes=send_splits)
+ return recv
+
+
+def _fsdp_bucket_schedule(items):
+ buckets, current, lookup = [], {}, {}
+ for p, info, _ in items:
+ key = (p.device, p.dtype)
+ idx = current.get(key)
+ size = info.total * p.element_size()
+ if idx is None or (buckets[idx][2] and buckets[idx][2] + size > _FSDP_BUCKET_BYTES):
+ idx = len(buckets)
+ buckets.append([p.device, p.dtype, 0])
+ current[key] = idx
+ buckets[idx][2] += size
+ lookup[info.param_idx] = idx
+ return [(device, dtype) for device, dtype, _ in buckets], lookup
+
+
+def _exchange_fsdp_shards(schedule, bucket_lookup, items, tensor_getter, keep_state=False):
+ ws = torch.distributed.get_world_size()
+ per_bucket = [[] for _ in schedule]
+ for p, info, shard in items:
+ tensor = tensor_getter(p, shard)
+ if tensor is None or tensor.numel() == 0:
+ continue
+ flat = tensor.reshape(-1)
+ bucket_idx = bucket_lookup[info.param_idx]
+ device, dtype = schedule[bucket_idx]
+ if flat.device != device or flat.dtype != dtype:
+ raise RuntimeError(f"FSDP bucket mismatch for param {info.param_idx}: expected {(device, dtype)}, got {(flat.device, flat.dtype)}")
+ per_bucket[bucket_idx].append((info.owner, info.param_idx, info.offset, flat, shard))
+
+ received, states = {}, []
+ for (device, dtype), bucket_entries in zip(schedule, per_bucket):
+ by_dst = [[] for _ in range(ws)]
+ for entry in bucket_entries:
+ by_dst[entry[0]].append(entry)
+
+ send_meta_splits = [len(dst_entries) * _FSDP_HEADER_WIDTH for dst_entries in by_dst]
+ send_payload_splits = [sum(flat.numel() for _, _, _, flat, _ in dst_entries) for dst_entries in by_dst]
+ recv_meta_splits = _exchange_split_sizes(send_meta_splits, device)
+ recv_payload_splits = _exchange_split_sizes(send_payload_splits, device)
+
+ code = _dtype_code(dtype)
+ meta = [
+ value
+ for dst_entries in by_dst
+ for _, param_idx, offset, flat, _ in dst_entries
+ for value in (param_idx, offset, flat.numel(), code)
+ ]
+ payload = [flat for dst_entries in by_dst for _, _, _, flat, _ in dst_entries]
+ send_meta = torch.tensor(meta, dtype=torch.int64, device=device) if meta else torch.empty(0, dtype=torch.int64, device=device)
+ send_payload = torch.cat(payload) if payload else torch.empty(0, dtype=dtype, device=device)
+
+ recv_meta = _all_to_all_variable(send_meta, recv_meta_splits, send_meta_splits)
+ recv_entries = [[] for _ in range(ws)]
+ meta_offset = 0
+ for src, count in enumerate(recv_meta_splits):
+ if count == 0:
+ continue
+ if count % _FSDP_HEADER_WIDTH:
+ raise RuntimeError(f"Malformed FSDP metadata split: {count}")
+ rows = recv_meta[meta_offset : meta_offset + count].view(-1, _FSDP_HEADER_WIDTH).cpu().tolist()
+ meta_offset += count
+ for param_idx, offset, length, got in rows:
+ if got != code:
+ raise RuntimeError(f"FSDP dtype mismatch for bucket {dtype}: expected {code}, got {got}")
+ recv_entries[src].append((param_idx, offset, length))
+
+ recv_payload = _all_to_all_variable(send_payload, recv_payload_splits, send_payload_splits)
+ payload_offset = 0
+ for src_entries in recv_entries:
+ for param_idx, offset, length in src_entries:
+ chunk = recv_payload[payload_offset : payload_offset + length]
+ received.setdefault(param_idx, []).append((offset, chunk))
+ payload_offset += length
+ if payload_offset != recv_payload.numel():
+ raise RuntimeError("FSDP payload unpack mismatch")
+
+ if keep_state:
+ states.append(_FSDPBucket(device, dtype, by_dst, send_payload_splits, recv_entries, recv_payload_splits))
+
+ return received, states
+
+
+def _reshape_fsdp_params(items):
+ rank = torch.distributed.get_rank()
+ schedule, bucket_lookup = _fsdp_bucket_schedule(items)
+ params, buckets = _exchange_fsdp_shards(schedule, bucket_lookup, items, lambda _, shard: shard, keep_state=True)
+ grads, _ = _exchange_fsdp_shards(schedule, bucket_lookup, items, lambda p, _: p.grad)
+
+ for p, info, shard in items:
+ p.grad = None
+ if info.owner != rank:
+ continue
+
+ pieces = params.get(info.param_idx, ())
+ total = sum(chunk.numel() for _, chunk in pieces)
+ if total != info.total:
+ raise RuntimeError(f"FSDP parameter assembly mismatch for param {info.param_idx}: {total} != {info.total}")
+
+ full = shard.new_empty(info.total)
+ for offset, chunk in pieces:
+ full[offset : offset + chunk.numel()].copy_(chunk)
+ p.data = full.view(info.orig_shape)
+
+ grad_pieces = grads.get(info.param_idx, ())
+ if not grad_pieces:
+ continue
+ grad_total = sum(chunk.numel() for _, chunk in grad_pieces)
+ if grad_total != info.total:
+ raise RuntimeError(f"FSDP grad assembly mismatch for param {info.param_idx}: {grad_total} != {info.total}")
+ grad = full.new_empty(info.total, dtype=grad_pieces[0][1].dtype)
+ for offset, chunk in grad_pieces:
+ grad[offset : offset + chunk.numel()].copy_(chunk)
+ p.grad = grad.view(info.orig_shape)
+
+ return _FSDPState(items, buckets)
+
+
+def _restore_fsdp_params(state):
+ by_param = {info.param_idx: (p, info, shard) for p, info, shard in state.items}
+ for bucket in state.buckets:
+ payload = []
+ for dst, recv_entries in enumerate(bucket.recv_entries):
+ for param_idx, offset, length in recv_entries:
+ p, info, _ = by_param[param_idx]
+ flat = p.data.reshape(-1)
+ if flat.numel() != info.total:
+ raise RuntimeError(f"FSDP return path expects full param {param_idx}, got {flat.numel()}")
+ payload.append(flat[offset : offset + length])
+ send_payload = torch.cat(payload) if payload else torch.empty(0, dtype=bucket.dtype, device=bucket.device)
+ recv_payload = _all_to_all_variable(send_payload, bucket.send_splits, bucket.recv_splits)
+
+ payload_offset = 0
+ for send_entries in bucket.send_entries:
+ for _, _, _, flat, shard in send_entries:
+ shard.copy_(recv_payload[payload_offset : payload_offset + flat.numel()])
+ payload_offset += flat.numel()
+ if payload_offset != recv_payload.numel():
+ raise RuntimeError("FSDP return payload unpack mismatch")
+
+ for p, _, shard in state.items:
+ p.data = shard
+ p.grad = None
def _view_param(p, shape):
@@ -1615,57 +1797,55 @@ def _view_param(p, shape):
def _reshape_params(params, orig_shapes, gather=True):
if not orig_shapes:
return [], []
- rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ dist_ready = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
views, gathers = [], []
for p in params:
info = orig_shapes.get(id(p))
- if info is None or p.data.shape == info.orig_shape:
+ if info is None:
continue
- if info.group is not None and gather:
+ if gather and dist_ready and info.owner is not None:
shard = p.data
- full = _reduce_gather(shard, info.offset, info.total, info.group, info.owner)
- if rank == info.owner:
- p.data = full.view(info.orig_shape)
- if p.grad is not None:
- p.grad = _reduce_gather(p.grad, info.offset, info.total, info.group, info.owner).view(
- info.orig_shape
- )
- else:
- del full
- if p.grad is not None:
- _reduce_gather(p.grad, info.offset, info.total, info.group, info.owner)
- p.grad = None
gathers.append((p, info, shard))
+ continue
+
+ if p.data.shape == info.orig_shape:
+ continue
+
+ orig, numel = info.orig_shape, p.data.numel()
+ if numel == info.total:
+ target = orig
+ elif numel > 0 and len(orig) >= 2:
+ inner = math.prod(orig[1:])
+ target = (numel // inner, *orig[1:]) if numel % inner == 0 else None
else:
- orig, numel = info.orig_shape, p.data.numel()
- if numel == info.total:
- target = orig
- elif numel > 0 and len(orig) >= 2:
- inner = math.prod(orig[1:])
- target = (numel // inner, *orig[1:]) if numel % inner == 0 else None
- else:
- continue
- if target is not None:
- flat = p.data.shape
- _view_param(p, target)
- views.append((p, flat))
+ continue
+ if target is not None:
+ flat = p.data.shape
+ _view_param(p, target)
+ views.append((p, flat))
+
+ if gathers:
+ gathers = _reshape_fsdp_params(gathers)
return views, gathers
def _restore_params(views, gathers):
- rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
- for p, info, shard in gathers:
- if rank == info.owner:
- full = p.data.flatten()
- else:
- full = shard.new_empty(info.total)
- torch.distributed.broadcast(full, src=info.owner, group=info.group)
- shard.copy_(full[info.offset : info.offset + shard.numel()])
- p.data = shard
- p.grad = None
+ if isinstance(gathers, _FSDPState):
+ _restore_fsdp_params(gathers)
+ else:
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ for p, info, shard in gathers:
+ if rank == info.owner:
+ full = p.data.flatten()
+ else:
+ full = shard.new_empty(info.total)
+ torch.distributed.broadcast(full, src=info.owner, group=info.group)
+ shard.copy_(full[info.offset : info.offset + shard.numel()])
+ p.data = shard
+ p.grad = None
for p, flat in views:
_view_param(p, flat)
@@ -1868,7 +2048,7 @@ def _find_val_name(self, name):
def _step(self, group):
if "base_lr" not in group:
group["base_lr"] = group["lr"]
- if "prev_lr" in group and group["prev_lr"] != group["lr"]:
+ if "base_lr" in group and group["base_lr"] != group["lr"]:
utils.warn_once(
f"Learning rate changed between steps. This is an experimental feature and "
f"only supported with multi_tensor=True (currently multi_tensor={group['multi_tensor']})."
@@ -1911,10 +2091,10 @@ def _step_inner(self, group):
if isinstance(step, torch.Tensor):
step = step.to(device=p[0].device, dtype=torch.int64)
else:
- step = torch.tensor(step, dtype=torch.int64, device=p[0].device)
+ step = utils.scalar_guard(step, p[0])
group["_group_step"] = group["step"] = step = step + 1
self.state_(p[0])["step"] = step
- group["prev_lr"] = group["lr"] = group["base_lr"] * step / max(step, group["warmup_steps"] + 1)
+ group["prev_lr"] = group["lr"] = group["base_lr"] * step / step.clamp(min=group["warmup_steps"] + 1)
if not group["multi_tensor"] or len(p) == 1:
for param, grad in zip(p, g):
@@ -1923,7 +2103,7 @@ def _step_inner(self, group):
self._chain(group, g, p, caution)
group["caution"] = caution
- group["lr"] = group["prev_lr"]
+ group["lr"] = group["base_lr"]
group["step"] = None
def _run_chain(self, state, group, g, p, caution):
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 76e6da3..e9c6791 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1269,6 +1269,28 @@ def hasattr_none(obj, name):
return getattr(obj, name, None) is not None
+def set_temporary(group: dict, tensor: Tensor, **kwargs):
+ if not kwargs:
+ return
+ state = group.setdefault("_tmp", {}).setdefault(id(tensor), {"tensor": tensor})
+ state.update(kwargs)
+
+
+def get_temporary(group: dict, tensor: Tensor):
+ tmp = group.get("_tmp")
+ return None if tmp is None else tmp.get(id(tensor))
+
+
+def take_temporary(group: dict, tensor: Tensor, *keys):
+ state = get_temporary(group, tensor)
+ if state is None:
+ return None if len(keys) == 1 else (None,) * len(keys)
+ out = tuple(state.pop(key, None) for key in keys)
+ if len(state) == 1:
+ group["_tmp"].pop(id(tensor), None)
+ return out[0] if len(keys) == 1 else out
+
+
class ExactHVPFailed(ValueError):
pass
@@ -1385,21 +1407,29 @@ def split_p_and_g_in_group(
should_promote: bool = True,
raw: bool = False,
):
+ tmp = group.get("_tmp")
for p in group["params"]:
- grad = getattr(p, "grad", None)
+ if raw:
+ grad = getattr(p, "grad", None)
+ if grad is None and skip_none:
+ continue
+ p.grad = None
+ yield p, grad
+ continue
+
+ state = None if tmp is None else tmp.get(id(p))
+ grad = None if state is None else state.pop("grad", None)
+ if grad is None:
+ grad = getattr(p, "grad", None)
if grad is None and skip_none:
continue
p.grad = None
- if raw:
- yield p, grad
- continue
-
if group.get("merge_dims", False) and not p.data.is_contiguous():
for fmt in (torch.channels_last, torch.channels_last_3d):
if p.data.is_contiguous(memory_format=fmt):
- p._restore_memory_format = fmt
+ set_temporary(group, p, restore_memory_format=fmt)
break
p.data = p.data.contiguous()
@@ -1407,20 +1437,23 @@ def split_p_and_g_in_group(
for i, pv in enumerate(p_views):
self.mapping_inverse[_tensor_key(pv)] = (p, i)
- vector = getattr(p, "vector", None)
- hessian_vector = getattr(p, "hessian_vector", None)
- p.vector = None
- p.hessian_vector = None
-
- grad, vs, hvs = [
- [None] * len(p_views) if x is None else merge_group(group, x) #
- for x in (grad, vector, hessian_vector)
- ]
+ if state is None:
+ vector = hessian_vector = None
+ else:
+ vector = state.pop("vector", None)
+ hessian_vector = state.pop("hessian_vector", None)
+ if len(state) == 1:
+ tmp.pop(id(p), None)
+ grad = itertools.repeat(None, len(p_views)) if grad is None else merge_group(group, grad)
+ vs = itertools.repeat(None, len(p_views)) if vector is None else merge_group(group, vector)
+ hvs = itertools.repeat(None, len(p_views)) if hessian_vector is None else merge_group(group, hessian_vector)
for pv, g, v, hv in zip(p_views, grad, vs, hvs):
g = promote_detach(g, should_promote)
- pv.vector = promote_detach(v, should_promote)
- pv.hessian_vector = promote_detach(hv, should_promote)
+ v = promote_detach(v, should_promote)
+ hv = promote_detach(hv, should_promote)
+ if v is not None or hv is not None:
+ set_temporary(group, pv, vector=v, hessian_vector=hv)
yield pv, g
def state_size(self) -> int:
@@ -1494,10 +1527,10 @@ def _finite_differences_hvp(self, closure):
for group in self.param_groups:
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
grads.append(g)
- p.vector = torch.randn_like(p)
- p.orig = p.data.clone()
+ vector = torch.randn_like(p)
+ set_temporary(group, p, vector=vector, orig=p.data.clone())
# scale taken from https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L2161
- stochastic_add_(p.data, p.vector, torch.finfo(torch.float32).eps ** 0.5)
+ stochastic_add_(p.data, vector, torch.finfo(torch.float32).eps ** 0.5)
with torch.enable_grad():
closure()
@@ -1506,22 +1539,22 @@ def _finite_differences_hvp(self, closure):
# this costs more memory, but the imprecision seems too severe to use the other method
for group in self.param_groups:
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
- p.grad = grads.pop(0)
- stochastic_add_divide_(g, p.grad, -1, torch.finfo(torch.float32).eps ** 0.5)
- p.hessian_vector = g
- p.data.copy_(p.orig)
- del p.orig
+ grad = grads.pop(0)
+ set_temporary(group, p, grad=grad, hessian_vector=g)
+ stochastic_add_divide_(g, grad, -1, torch.finfo(torch.float32).eps ** 0.5)
+ p.data.copy_(take_temporary(group, p, "orig"))
return loss
def _double_backward_hvp(self, closure):
with torch.enable_grad(), patch_backward():
loss = closure()
- params, grads = [], []
+ params, grads, groups = [], [], []
for group in self.param_groups:
for p, g in self.split_p_and_g_in_group(group, skip_none=True, raw=True):
params.append(p)
grads.append(g)
+ groups.append(group)
if not params:
raise ValueError("No parameter has gradients")
@@ -1534,10 +1567,8 @@ def _double_backward_hvp(self, closure):
raise ExactHVPFailed(str(e.args))
unused = []
- for p, g, v, hv in zip(params, grads, vs, hvs):
- p.hessian_vector = detach(hv)
- p.grad = detach(g)
- p.vector = detach(v)
+ for group, p, g, v, hv in zip(groups, params, grads, vs, hvs):
+ set_temporary(group, p, grad=detach(g), vector=detach(v), hessian_vector=detach(hv))
if hv is None:
unused.append(list(p.shape))
@@ -1591,17 +1622,18 @@ def _handle_closure(self, closure):
return self._handle_closure(closure)
def _cleanup_temporary_tensors(self):
- for real, views in self.mapping.items():
- fmt = getattr(real, "_restore_memory_format", None)
- if fmt is not None:
- del self.mapping_inverse[_tensor_key(real)]
- real.data = real.data.to(memory_format=fmt)
- self.mapping_inverse[_tensor_key(real)] = (real, 0)
- del real._restore_memory_format
- for tensor in (real, *views):
- for key in ("grad", "vector", "hessian_vector", "orig"):
- if hasattr(tensor, key):
- setattr(tensor, key, None)
+ for group in self.param_groups:
+ tmp = group.pop("_tmp", None)
+ if tmp is None:
+ continue
+ for state in tmp.values():
+ fmt = state.get("restore_memory_format")
+ if fmt is None:
+ continue
+ tensor = state["tensor"]
+ self.mapping_inverse.pop(_tensor_key(tensor), None)
+ tensor.data = tensor.data.to(memory_format=fmt)
+ self.mapping_inverse[_tensor_key(tensor)] = (tensor, 0)
def step(self, closure: Optional[Callable] = None):
if self.precond_schedule is None:
@@ -2172,10 +2204,10 @@ def update_param_(
grad = [None] * len(param)
_compilable_update_(param, update, decay, lr, caution, grad)
-
-def precond_schedule(step, precond_scheduler):
- precond_prob = max(step, 1) ** precond_scheduler[0]
- precond_prob = math.log10(precond_prob)
+@decorator_knowngood
+def precond_schedule(step: Tensor, precond_scheduler):
+ precond_prob = step.clamp(min=1) ** precond_scheduler[0]
+ precond_prob = torch.log10(precond_prob)
precond_prob = precond_prob ** precond_scheduler[1] + 1
return 1 / precond_prob
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 1663f13..d96d9d1 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -50,6 +50,8 @@
if n in REPRESENTATIVE_OPTS and n not in _FSDP_SKIP
]
+_OWNER_EDGE_OPTS = [n for n in ["Muon"] if n in REPRESENTATIVE_OPTS and n not in _FSDP_SKIP]
+
def _set_cache(cache_dir, compile_mode="default"):
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
@@ -101,6 +103,14 @@ def _make_integration_model():
)
+def _make_owner_edge_model():
+ return nn.Sequential(
+ nn.Linear(32, 32, bias=False),
+ nn.ReLU(),
+ nn.Linear(32, 1, bias=False),
+ )
+
+
def _make_data(dim=32, n=4):
torch.manual_seed(_DATA_SEED)
return [torch.randn(4, dim, device="cuda") for _ in range(n)]
@@ -118,6 +128,10 @@ def _make_integration_data():
return _make_data(64, 8)
+def _make_owner_edge_data():
+ return _make_data(32, 6)
+
+
def _train(model, opt, data):
for x in data:
model(x).mean().backward()
@@ -263,7 +277,7 @@ def _assert_close(ref, result, label, rtol=0, atol=0):
assert False, f"{label}: param {i} diverged beyond 1 ULP ({n} elements, max |diff|={worst:.2e})"
-def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label):
+def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label, world_size=2, tol=None):
cache_dir = tempfile.mkdtemp(prefix=f"hb_{label}_{opt_name}_")
ref_path = str(tmp_path / "ref.pt")
mp.spawn(_ref_worker, args=(opt_name, ref_path, cache_dir, None, model_fn, data_fn), nprocs=1, join=True)
@@ -271,12 +285,14 @@ def _run_fsdp_test(opt_name, tmp_path, model_fn, data_fn, label):
result_path = str(tmp_path / "result.pt")
mp.spawn(
_fsdp_worker,
- args=(2, str(tmp_path / "store"), opt_name, result_path, cache_dir, None, model_fn, data_fn),
- nprocs=2,
+ args=(world_size, str(tmp_path / "store"), opt_name, result_path, cache_dir, None, model_fn, data_fn),
+ nprocs=world_size,
join=True,
)
- tol = dict(rtol=1e-2, atol=1e-4) if opt_name in _FSDP_PSGD else {}
- _assert_close(ref, torch.load(result_path, weights_only=True), f"{label}/{opt_name}", **tol)
+ base_tol = dict(rtol=1e-2, atol=1e-4) if opt_name in _FSDP_PSGD else {}
+ if tol is not None:
+ base_tol.update({k: max(base_tol.get(k, 0), v) for k, v in tol.items()})
+ _assert_close(ref, torch.load(result_path, weights_only=True), f"{label}/{opt_name}", **base_tol)
@pytest.mark.parametrize("opt_name", REPRESENTATIVE_OPTS)
@@ -336,3 +352,16 @@ def test_fsdp_misaligned(opt_name, tmp_path):
@pytest.mark.parametrize("opt_name", _INTEGRATION_OPTS)
def test_fsdp_integration(opt_name, tmp_path):
_run_fsdp_test(opt_name, tmp_path, _make_integration_model, _make_integration_data, "FSDP-integ")
+
+
+@pytest.mark.parametrize("opt_name", _OWNER_EDGE_OPTS)
+def test_fsdp_owner_edge(opt_name, tmp_path):
+ _run_fsdp_test(
+ opt_name,
+ tmp_path,
+ _make_owner_edge_model,
+ _make_owner_edge_data,
+ "FSDP-owner3",
+ world_size=3,
+ tol=dict(atol=5e-8),
+ )
From 94a5974f25b8c2947f679a04c41adfe0a6efd7a5 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Mon, 30 Mar 2026 18:59:16 +0000
Subject: [PATCH 12/24] ruff
---
heavyball/chainable.py | 10 ++++++++--
heavyball/utils.py | 1 +
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index 5204dca..5b79a31 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -1670,7 +1670,9 @@ def _exchange_fsdp_shards(schedule, bucket_lookup, items, tensor_getter, keep_st
bucket_idx = bucket_lookup[info.param_idx]
device, dtype = schedule[bucket_idx]
if flat.device != device or flat.dtype != dtype:
- raise RuntimeError(f"FSDP bucket mismatch for param {info.param_idx}: expected {(device, dtype)}, got {(flat.device, flat.dtype)}")
+ raise RuntimeError(
+ f"FSDP bucket mismatch for param {info.param_idx}: expected {(device, dtype)}, got {(flat.device, flat.dtype)}"
+ )
per_bucket[bucket_idx].append((info.owner, info.param_idx, info.offset, flat, shard))
received, states = {}, []
@@ -1692,7 +1694,11 @@ def _exchange_fsdp_shards(schedule, bucket_lookup, items, tensor_getter, keep_st
for value in (param_idx, offset, flat.numel(), code)
]
payload = [flat for dst_entries in by_dst for _, _, _, flat, _ in dst_entries]
- send_meta = torch.tensor(meta, dtype=torch.int64, device=device) if meta else torch.empty(0, dtype=torch.int64, device=device)
+ send_meta = (
+ torch.tensor(meta, dtype=torch.int64, device=device)
+ if meta
+ else torch.empty(0, dtype=torch.int64, device=device)
+ )
send_payload = torch.cat(payload) if payload else torch.empty(0, dtype=dtype, device=device)
recv_meta = _all_to_all_variable(send_meta, recv_meta_splits, send_meta_splits)
diff --git a/heavyball/utils.py b/heavyball/utils.py
index e9c6791..8ecad17 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -2204,6 +2204,7 @@ def update_param_(
grad = [None] * len(param)
_compilable_update_(param, update, decay, lr, caution, grad)
+
@decorator_knowngood
def precond_schedule(step: Tensor, precond_scheduler):
precond_prob = step.clamp(min=1) ** precond_scheduler[0]
From c2676bd678296ee9501204cabb49b0d0634e0f05 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Tue, 31 Mar 2026 10:28:36 +0000
Subject: [PATCH 13/24] add optimizer step benchmark
---
benchmarks/bench_optimizer_step.py | 67 ++++++++++++++++++++++++++++++
heavyball/utils.py | 7 +++-
2 files changed, 72 insertions(+), 2 deletions(-)
create mode 100644 benchmarks/bench_optimizer_step.py
diff --git a/benchmarks/bench_optimizer_step.py b/benchmarks/bench_optimizer_step.py
new file mode 100644
index 0000000..06d2c73
--- /dev/null
+++ b/benchmarks/bench_optimizer_step.py
@@ -0,0 +1,67 @@
+import cProfile
+import pstats
+from enum import StrEnum
+from math import prod
+from time import perf_counter
+import numpy as np
+import torch
+import typer
+
+import heavyball
+
+app = typer.Typer(add_completion=False, pretty_exceptions_enable=False)
+
+DEFAULT_SHAPES = ((128, 128),) * 4 + ((512, 128),) * 2 + ((512,),) * 2 + ((128,),) * 4
+
+
+class DType(StrEnum):
+ float16 = "float16"
+ bfloat16 = "bfloat16"
+ float32 = "float32"
+
+
+def parse_shape(text: str) -> tuple[int, ...]:
+ try:
+ shape = tuple(map(int, text.lower().replace("x", " ").split()))
+ except ValueError as e:
+ raise typer.BadParameter(f"invalid shape: {text!r}") from e
+ if not shape:
+ raise typer.BadParameter(f"invalid shape: {text!r}")
+ return shape
+
+
+@app.command()
+def main(optimizer: str = "AdamW", dtype: DType = DType.float32, shape: list[str] | None = None,
+ compile_step: bool = False, update_precond: bool | None = None, steps: int = 300, warmup: int = 20,
+ windows: int = 6, seed: int = 0, ):
+ shapes = DEFAULT_SHAPES if shape is None else tuple(map(parse_shape, shape))
+ torch_dtype = getattr(torch, dtype)
+ kwargs = {"compile_step": compile_step}
+ if update_precond is not None:
+ kwargs["preconditioner_update_probability"] = float(update_precond)
+
+ gen = torch.Generator(device="cuda").manual_seed(seed)
+ params = []
+ for dims in shapes:
+ param = torch.nn.Parameter(torch.randn(dims, device="cuda", dtype=torch_dtype, generator=gen))
+ param.grad = torch.randn(dims, device="cuda", dtype=torch_dtype, generator=gen)
+ params.append(param)
+
+ step = getattr(heavyball, optimizer)(params, **kwargs).step
+ for _ in range(warmup):
+ step()
+
+ times = []
+ for _ in range(windows):
+ torch.cuda.synchronize()
+ start = perf_counter()
+ for _ in range(steps):
+ step()
+ torch.cuda.synchronize()
+ times.append((perf_counter() - start) / steps)
+
+ print(f"{len(shapes)} tensors, {sum(prod(s) for s in shapes)} total params")
+ print(f'Median Time: {np.median(times) * 1e6:.3f}µs')
+
+if __name__ == "__main__":
+ app()
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 8ecad17..5da9d51 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1318,6 +1318,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
fallback_to_finite_differences: bool = True
_fallback_enabled: bool = False
hvp_interval: int = 1 # grad is faster initially, hvp later
+ auto_set_grad_to_none: bool = False
_INSTANCE_ATTRS = (
"compile_step",
@@ -1413,7 +1414,8 @@ def split_p_and_g_in_group(
grad = getattr(p, "grad", None)
if grad is None and skip_none:
continue
- p.grad = None
+ if self.auto_set_grad_to_none:
+ p.grad = None
yield p, grad
continue
@@ -1424,7 +1426,8 @@ def split_p_and_g_in_group(
if grad is None and skip_none:
continue
- p.grad = None
+ if self.auto_set_grad_to_none:
+ p.grad = None
if group.get("merge_dims", False) and not p.data.is_contiguous():
for fmt in (torch.channels_last, torch.channels_last_3d):
From 6afa484747a8e3bc9e8cc969349b86cd795f90dc Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Tue, 31 Mar 2026 16:07:49 +0000
Subject: [PATCH 14/24] run ortho in scion, round ns iteration
---
heavyball/utils.py | 28 ++++++++++++++--------------
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 5da9d51..099da96 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -656,7 +656,7 @@ def _scion_bias_rms_direction(x: Tensor, eps: float = 1e-8) -> Tensor:
def _scion_spectral_direction(x: Tensor) -> Tensor:
flat = x.reshape(x.shape[0], -1)
- inplace_orthogonal_(flat)
+ flat = inplace_orthogonal_(flat)
normalized = flat.reshape_as(x)
in_dim = max(flat.shape[1], 1)
scale = math.sqrt(x.shape[0] / in_dim)
@@ -665,7 +665,7 @@ def _scion_spectral_direction(x: Tensor) -> Tensor:
def _scion_spectral_conv_direction(x: Tensor) -> Tensor:
flat = x.reshape(x.shape[0], -1)
- inplace_orthogonal_(flat)
+ flat = inplace_orthogonal_(flat)
normalized = flat.reshape_as(x)
out_channels, in_channels = x.shape[:2]
spatial = math.prod(x.shape[2:]) if x.ndim > 2 else 1
@@ -774,7 +774,7 @@ def _compilable_grafting(magnitude, direction):
def _compilable_orthogonal_(x: Tensor, mode: str | ZerothPowerMode, out: Tensor | None, scale_mode: str):
if not isinstance(mode, ZerothPowerMode):
mode = ZerothPowerMode(mode)
- if not isinstance(scale_mode, ZerothPowerMode):
+ if not isinstance(scale_mode, OrthoScaleMode):
scale_mode = OrthoScaleMode(scale_mode)
if mode == ZerothPowerMode.newtonschulz or x.shape[0] != x.shape[1]:
y = zeropower_via_newtonschulz5(x, 5)
@@ -2147,17 +2147,17 @@ def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
def stochastic_round_(ref: Tensor, source: Tensor | None = None):
if source is None:
source = ref
- if ref.dtype != torch.bfloat16:
- return source.to(ref.dtype)
- if source.dtype == torch.bfloat16:
- return source
- if source.dtype in (torch.float16, torch.float32, torch.float64):
- source = source.to(torch.float32).view(dtype=torch.int32)
- noise = sum(torch.randint_like(source, low=0, high=(1 << 16)) for _ in range(dither_steps))
- noise = noise + source - (dither_steps - 1) * (1 << 15) # center | x - (N-1)*delta/2
- noise = noise.bitwise_and(-65536) # FFFF0000 mask, preserves sign+exp+7 mantissa bits
- return noise.view(dtype=torch.float32).bfloat16()
- return source.to(ref.dtype)
+ dtype = torch.bfloat16
+ else:
+ dtype = ref.dtype
+ if dtype != torch.bfloat16 or source.dtype not in (torch.float16, torch.float32, torch.float64):
+ return source.to(dtype)
+
+ source = source.to(torch.float32).view(dtype=torch.int32)
+ noise = sum(torch.randint_like(source, low=0, high=(1 << 16)) for _ in range(dither_steps))
+ noise = noise + source - (dither_steps - 1) * (1 << 15) # center | x - (N-1)*delta/2
+ noise = noise.bitwise_and(-65536) # FFFF0000 mask, preserves sign+exp+7 mantissa bits
+ return noise.view(dtype=torch.float32).bfloat16()
@decorator_knowngood
From e387c8af051b41aed053da90f96f5e5fe7e6f360 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Tue, 31 Mar 2026 19:25:22 +0000
Subject: [PATCH 15/24] compile more
---
benchmarks/bench_optimizer_step.py | 36 +++++++++++++++++++++---------
heavyball/chainable.py | 1 -
heavyball/utils.py | 9 ++++----
3 files changed, 30 insertions(+), 16 deletions(-)
diff --git a/benchmarks/bench_optimizer_step.py b/benchmarks/bench_optimizer_step.py
index 06d2c73..07ca682 100644
--- a/benchmarks/bench_optimizer_step.py
+++ b/benchmarks/bench_optimizer_step.py
@@ -1,5 +1,3 @@
-import cProfile
-import pstats
from enum import StrEnum
from math import prod
from time import perf_counter
@@ -11,7 +9,7 @@
app = typer.Typer(add_completion=False, pretty_exceptions_enable=False)
-DEFAULT_SHAPES = ((128, 128),) * 4 + ((512, 128),) * 2 + ((512,),) * 2 + ((128,),) * 4
+DEFAULT_SHAPES = ((2048, 2048),) * 32
class DType(StrEnum):
@@ -20,6 +18,11 @@ class DType(StrEnum):
float32 = "float32"
+class Library(StrEnum):
+ heavyball = "heavyball"
+ torch = "torch"
+
+
def parse_shape(text: str) -> tuple[int, ...]:
try:
shape = tuple(map(int, text.lower().replace("x", " ").split()))
@@ -31,13 +34,25 @@ def parse_shape(text: str) -> tuple[int, ...]:
@app.command()
-def main(optimizer: str = "AdamW", dtype: DType = DType.float32, shape: list[str] | None = None,
- compile_step: bool = False, update_precond: bool | None = None, steps: int = 300, warmup: int = 20,
- windows: int = 6, seed: int = 0, ):
+def main(
+ optimizer: str = "AdamW",
+ library: Library = Library.heavyball,
+ dtype: DType = DType.float32,
+ shape: list[str] | None = None,
+ compile_step: bool = False,
+ fused: bool | None = None,
+ update_precond: bool | None = None,
+ steps: int = 300,
+ warmup: int = 20,
+ windows: int = 6,
+ seed: int = 0,
+):
shapes = DEFAULT_SHAPES if shape is None else tuple(map(parse_shape, shape))
torch_dtype = getattr(torch, dtype)
- kwargs = {"compile_step": compile_step}
- if update_precond is not None:
+ kwargs = {"compile_step": compile_step} if library is Library.heavyball else {}
+ if fused is not None and library is Library.torch:
+ kwargs["fused"] = fused
+ if update_precond is not None and library is Library.heavyball:
kwargs["preconditioner_update_probability"] = float(update_precond)
gen = torch.Generator(device="cuda").manual_seed(seed)
@@ -47,7 +62,8 @@ def main(optimizer: str = "AdamW", dtype: DType = DType.float32, shape: list[str
param.grad = torch.randn(dims, device="cuda", dtype=torch_dtype, generator=gen)
params.append(param)
- step = getattr(heavyball, optimizer)(params, **kwargs).step
+ module = heavyball if library is Library.heavyball else torch.optim
+ step = getattr(module, optimizer)(params, **kwargs).step
for _ in range(warmup):
step()
@@ -61,7 +77,7 @@ def main(optimizer: str = "AdamW", dtype: DType = DType.float32, shape: list[str
times.append((perf_counter() - start) / steps)
print(f"{len(shapes)} tensors, {sum(prod(s) for s in shapes)} total params")
- print(f'Median Time: {np.median(times) * 1e6:.3f}µs')
+ print(f"Median Time: {np.median(times) * 1e6:.3f}µs")
if __name__ == "__main__":
app()
diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index 5b79a31..e8d8b2e 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -1182,7 +1182,6 @@ def _update_psgd_precond(
Q,
group["store_triu_as_line"],
utils.get_beta2(group),
- group["ortho_method"],
vector,
running_lower_bound,
group["lower_bound_beta"],
diff --git a/heavyball/utils.py b/heavyball/utils.py
index 099da96..04fe270 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -2813,15 +2813,15 @@ def max_eigenvalue_spd(A_outer: Tensor, power_iter: int = 4) -> Tensor:
def _inner():
x = A_outer.index_select(0, max_idx).flatten().contiguous()
- A = stochastic_round_(A_outer / x_norm)
+ A = promote(A_outer) / x_norm
x = x / x_norm
def _mv(x):
- return promote((x.to(A.dtype) @ A.mT) @ A.mT)
+ return promote((x @ A.mT) @ A.mT)
for _ in range(power_iter):
x = F.normalize(_mv(x), dim=0)
- return (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
+ return (x @ _mv(x)).sqrt() * x_norm
return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone()).squeeze()
@@ -3006,14 +3006,13 @@ def _update_lb(ell: Tensor, lb_state: Tensor, beta: Tensor) -> Tensor:
return ell
-@decorator
+@functools.partial(decorator_knowngood, fullgraph=False)
def psgd_update_precond(
G: Tensor,
precond_lr: float,
oq: "TriuOrLine",
store_triu_as_line: bool,
beta2: float,
- ortho_method: Optional[str],
V: Tensor,
running_lower_bound: List[Tensor],
lower_bount_beta: float,
From 2c77d49d39f29695602f13fd44dd00408a8e2455 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Wed, 1 Apr 2026 09:17:29 +0000
Subject: [PATCH 16/24] update benchmark matrix
---
assets/benchmark_matrix.png | Bin 2496725 -> 1160757 bytes
1 file changed, 0 insertions(+), 0 deletions(-)
diff --git a/assets/benchmark_matrix.png b/assets/benchmark_matrix.png
index 8300fcafd11de224451ad0f2c8894d883177bc79..26a7839c3aff36118840f9909658c104c666fc44 100644
GIT binary patch
literal 1160757
zcmeEOiC@m?`ffPJ3}b1NXrZJ+l2l5|phZP#L8XO4+T^88i^dqVXUSTMv}>hJ+VEEP
zlvG-IlQxyKYW-dJd#3q0pU?Rdj?XzW6MFl8pXYw=`?{|Cx*z|YdRp`5ESbZ>!7)#J
zo4O$f2cHoK$HF(i{EXiTKeO;F{-xxiVd7%se8|Op|FMG{JNCOAb#!)dw6k93cJSD7
zJLe;^Yh~80m07dwu#3yl<4V%fPXGLkwa&+ErH?yLw&GKMJ-W^GI0wgCQTqQ+9=4@#`{O^DAd*X`ig8%)GBrgAz
zIP<^%QS8NauK)c{ZW(jM&HwLz6sGgI3M=?ue*IyneNE}V{62nb=c-@-_fqVo`QJ|c
z6CwW#zv-p;-v>6Q3jqL^L@j{hZ!>81EzqL^NZovSWOcKkmM`D0#rgyM1Q
zx=cG6o6;bYY)8{ipFXKuT3Uv!-m~`k<<)%el6D13eYaJ-1<7tkM+V<>Q
z>x-7L-l9-WQB7vK+LVA%yk2m}+xW7|SY4a<_YPhN4_|uld6@qk-c{h
z1jM>)YHA)hbZBx>u4#sKNrGXlj`7{yQufu!{H|Tq!KJ}V%iQXn_y#-Qt-iuv>ymY(
zQOTf@TVMG??v(+h&V)O(Dq<>{r`#G{Ydks=<8?fPf=i`~#<{xdbM%et@9s&f_3TOs
zJR!2J=b3m^P<+(MyyCoh-FX=veXRDk@%rB5ePuGG
zdyS><#GuDmcaF`sPtVTu+~4=;&&BHxZ2OIW{ffwwqa6wO7=hz2Wf&}f@i~`f%JYB!
z{ypPZ>vNL_#}0n_{JHecS#zyN2Ra;*Q?P3SM~i=Ae0wIY)tKj6(;pLCKhW{^j(cwt
zW3(gw>VO1CvF3r!s>HtW{_?Wf!rluXJa}N&@c^|6Xz5wtVB%c)zYE&-AO@`q#lL1Jhq{aQxnJF)B(pt#H&{z_rSt
z@9UEVe3ILIb{u+njZ9zzA|B9=YqlAb$?b(tx-@=
zX!7oVwXu)MeP?gZk)~#`^H;C(v7doH_4TISuz`)!kNEqq+&`fC@k7Ga_Q=VW57cIEhikF2mX$3W97`{htBxhBFBrNZ
zYpdtd)1VmX*_p`oz9{pClB>P>@mF$;z7oE&k-ipPPl4%0y2Zo!!y++}O`>heLi%1v
z7ll2YoV#d^U!Zj1#+BQz?GzcC8RizqCmpiZB9Hmy{e4}}Flz1w}t_RPPh@nM;^_)Q+4u_37r
z#`N?2VYuxFHs#1dir`gfIJ))aHJLl=Z6^614D%DMmor>CD!FdXyp>6B+JK0_LC-M7
z`50)g5DihA9BF@l-!|F&!Mzx*865KKqrAUqk9S#3t=iQq+jOGsW^PAZxCrGcElEyc
zp8*2_-|@b_FZatej~qFo9r_P#3^V+J^X=mv9;$cNa&vRj4&~(JG>Y5qUnb<098<<2
zkwyD#<=v;D9cr+q=V5PQZd;_6^`k#$iK)!tNaq~=9PUh6A$EP{cuS~Sh?cNVW=xq$
zucdciNp4^1Qr+>R3e&HDYhbleYl$+6}s&9`woT95^#tAdtB$Le0aGVsgv
zTNVGj<;P8ZtU599_8N9Bx(;=y?7D3vzwJ%e-Np$!_Xo>S#j<~wq;W?xn00>ZeE++t
zd8<$(9(eRTXyZEfECoN$(yKcwceL{^x34fq*WNVC0_(1n9823BP#)CPlg+YLQP8a3
zlN!|e;%XUQ!+Pk;d&=YRGnleBN6$D1
zg^u@x^dIvKYEmt@eTOCrEZiIA*5-eGoMmZy|K7dmEobH!pn}!9)>yOz$-ee8y?>}=
zu5|uF6|bQn-qpKfGMjM#1RwU~IaH-uY-%|dinBY*w9Rg9qZSx7h%<$f0dk^
zEU`D+5ev6G$vBZK%b_NODzlPH`$9^Y4adt>B+odZq{Pz?<{CkUwfRtdN-`zUB*Pj<
zNGwEkv^{M2F5bZ0XYd{4OGa7g4uwXZJ6fT7`EGrz5(bJ?8S-Gnx58vocIc;9bk=1#
zxOCTLA@cwJwU7H_(YhJaU;VM83U6Wa>FEW3`%SK$VaLvo_6mA`eZ+V!8NVVmilUTX
zHtx;XPmbF{REBDEyR+*J%`Od$&f`dzq`WX&$YWWo9FlOTZRCl#otrz0YH}GZ9>~!)
z2wh#f#^O^mZN7eZq*1qY`bX}w5&c02k~csG3&u{YNKzU3a5^yGidGcx!N03$tmmOf
z@2}~%H-0UiX@hXkH&^ozUT&;Wth(lKPve~v{h#gL7h%5{retmPP?v*wE6c1VU@1jpmj8*^?
zJ$?RsRR^bn5^7E#D^Pl6YWd5TLG?!-N-a@15>c)AJWP^1r}?B;1h!npwX;gyK{I#>
z?~`*2LUAHuT^H$ue(!H9Yp=LtgIY*AG-2oMEshx_i!!hYnm3hP1FF1w8Xgo4zQ#TT
z@UGUxVY~wfXOMHG@#)K#+uyqX6U$wpKZt0x>Qa{Z8yVey{mLDUHPdK1-dixNClIHK
z<0aCYlW}d`{w===Ds9jV5#U#w8b7{0N_AqvX5iC3A^3+Dg30D#cXr_CSK(S+)hTz(
zvh2H@O0j`VyfUlaxhPoJcUc$2;u1#Okw|>d*e{GOz{=how
zIB=PaOZ%IVkLQG~SxkRI<&>A#0Q3PabMo`^(_c)$naBUcfXuiDyJ|w6avJ6F(!BGQ
z%C8yXS2%#j@X;9zTqLEVdu=qq;0_DvDUi`pTB5Hu!PEcwRiK#e6(*p%DDV)We1uOZ
zLG|Dza?x(xbp}X@wRZ70t=n$8hjhNX%Xi{zrFBUFugrm`=byfQ9YU)ar+0G+Ku{>a
z7yH$cO=OUDd4Z+_1>GBD2sGp6HGs;)2HH#PpIsEIKxG)|yqji?3@UT{^GlOF7dxUE
zNdWWo65Tn?szbFlp@$puvVQ-tqXF?_-TvlwA4*;+JvXM&`m$(%6_
z>=->U#da+UvrUz{<|UanCM^PYnr}-8-O-M^gX(t@RG7
zl>y3MbEMX*aD~vD?oEm)h>HND#(m}G@Gp;+Ui*Cbzavbvg{F
zj(y>-$7Xw8x3kH_19+frvF?Pb20&8ZpW;f3|78BAHF}DzVazD)t5SOxQ5r|?oK35k
znw5$z%IGZ|E0rk|5KUGKUbNPHF&-)i(2E@!MdROe*O+BWP^eK%=v-b^vU{%+Z@uIF
z9TcP)NPk)Y13I2x{QRzMurtgbtllf*@%6*C^#^s5O)^BNip*QG!H-sRzUs)nG_Su?
z>@DM`Ps_{8r8|~S=Uk3I2Ga-`BT&SWkmI(t-~dT~izLj5+PB)ETW?c0524@ax;r6C
ztKxQq@7~u4F3P1P=xQ7t>Ff44(VLU}@WY>T33Z;SYZ<<>#*~X*dH}SpG*me1+vxg(
z&(*PKcYqW7a*8HWp7tld0)}n5yvNdC)n}9$tr2KwX@$4iIR5oX3j#3MDf|7(?*7)c
zHm#zQ6Rg;)mH{81oQu!QCZJSf=~a>#AGyS-E>nBVk&}~>BYsb6bhuX-l;cc~5x8cp
zZIlm-$*mtml|PtUlUNrB9v|?zUeFS`q-_PA)hSEQ3A&j?08$|X1#j4KO;y{fjgu44
z5s;G8q?9t&*4j$s_Ihq1mg&s=@hFs*e8f#3*k-b9L)J1R1KN->MwuWVirzlf*BS(y
z!V%=BExixr2|y?hRvAgAh0}MAumbx~fl9#s49y;5-?TtBSwMy%s8(_L!=rEVT)Wg=
z^_BaW!uEX9mIB}cj``m|hZE&iS0UOge$&N}n2qEt>mob|14(*ik_ujmn1xP*my@Mu
zN2~{;jl1T`=Y`K=idrbkHGwZ#U;kJp>>DJc;>kqa5}M1mc8yCEVYi=V3oI0m@;>$@
zwYR|Lgepopv*_HrY$W~M
zZ_if;wobHV+P!6|O?}t*Mq$lBE?0LTy<=nlOW77Y*VxvC(M$TO{76DBr@H2PIqRJx
z*5oF(q=elKKpdJ>9*
zt0QtX`6v2`uCe}@Ie(jRkr}X&OJ|j2teiHFi}UI)W4EFQ9*%!|LHKZ_dkQBfG*ZQj
zFojW`){*S@hZr5PS4+WuGGFJpH=Pg{I`QekH8tOptR_xQU|rh=WRf$5dPrAt*z#|)h0di2P$=EHQjZq-G43@F-5seqB3IoRBISy7Nt*r!NDPK0vx7qJ_Zu5
zb^7OdEHKXP0$~qVg3|>H+b`^Q=`|fU#EH!HvAMnZja>&F;Xv>
zA8Wfv<~TijLCB-iv$ar@S+u5yUftk*&^vk&>VtIt*DcEUS-^376ND%Q)2Z#N)Rivz
z-WDZH8SuiBg}SF9J*}S{O!F;icyQdvOu?Z#*#mexS}RnDokLaMw4^_oA4nl?@XNv8
zrh>Ozag*p%w=&S3p;>1CL2el&b
zt_(1&$>qP~Zv++C!$4^(Ly*tOI{BqGgM_wV16pcgN-ZK50^9gl4U
zL-Y3I(UWNiZ(6W8DUYGyW5<(H3doT`Nt$|2?IOJ
z=5tq;4XQ1bvp0@7{#voGLQ6Qdz@0(dBf@v`dmm^bpVGF8;VWw`mMl;kKN^7)Rq`#*
zG>X(o%DG1K4W|l-_P*O=CAdIkU}y9n-KavPC+4CJe`DUm}cnCQ^!3Z=`H0!0_&=khOX3(I%OP4eG|GOv@ux=$j;5Kr{=GbWu;;c^;IbP(~kb
zdrfZ>FAy{71yUO#Tdpiyh`%CIo~h~$`X|}g9IE*F8f)oicT!cLEAKG35{TlzB5ft4
z(CA{sX3*FPOXT+mw>mpQ&>Qq}7}tT9B2Ik1t}*F2gw-p->x`aDdljj&>9xvm1NJU0
z8|XQg6|mZx>Bp(jd0&!W)!T!uDuNlrwKl7>GZ&V!rFyTkRHdsD?{JfsGi%x4yL|C!
zNnPj0e^4ysT&dRI2T8FPyK2rwiC%fQ;W!D0);Es4m38}4yAkz<_>$C>Ny^Fu!y-{m
zMFvk`DevfBU0bq9QxR#9gk2jPfipdQ#WNQ$fpU1kDD5KF5KACfPznXZqP;j1YGWyI
z*7fT2ak=Z@F(fLUowsy%!r%g+!|`6(xUHc%WcFXvl7vNEQk30|?9qQhjJjU~m>CfHrQnq}1{i(S{!;rPuG7$2x_?Dkv!Iyl5X4jD;gUGg;
z)?8y^n<<@BJyahbtt)Jg4zwO0>9gy&dFc{2$)Ib%v!y2F4?*Ay30#{e@%iQTOG8bd
z1V;cr5^H@IVDl{V@zjzj&z?T@6Ae|qcWvdgG_y_fKUBE+%3q7Oq=X~wD1)~#Udog)
zNsp!T{M(WZ3AI}nD(t3093r@Bk?!J}C>5x7eF#@x`)-uarJ<^$hi{0F-UZ<=!7B*e
zKlDoHMEBMODm8C
zMgcKe&MrM-;yiYFH$Ex=SaU^cGsu!*=M;jJr+)uD-)xYxk>R4Ln})Y~M>inw+aZJB
zPLST;Sxnm2tw-O-l(wdU=FIcuYqlA}0Oy-ae5Q8-~nnH;Q
zwWYGo3kA3jX(~3Wt@Jb65urDR0?c9K0eds
zRE`=d4X%5s@&0Qtgx}MOCJ!nBfT~oCNAXQIbhBx+{$38UJ0*KuPprp&S9C>%ee2%8D9$U2I
zCtds4%sMDX)MPIGqc>
zA2Rctcn^Jrg*Vst@)$YWS;A+veQfj3oSZMNNCnEi-Y9x@fpFvy*o2DbfJtwZsR9q3
zyLeUnmaL69~e7~?v6=Lqf`rO|IDv$
z4RRHX+{W}KGpe@dWw!INlkLqJEQimewnjoDs
z>N`ZK`mg|_LSNt9q%|=*$X%{@^g}W8r)P~1-Sv6K&B4dnAOoc~1ekP_D#{%ZO>d-H
zHixDp1(0A45lXc$f)XAW{>s$m#YFUhxE94n9MjR+ALvw{9lUbKjTK3+KvypO{?X45
z95PdV$l@xZmRC2uwQhUFMwuRTO9r#8>;M>tOAzwarf<;^W=`OVEfN$_@4_Vpa#2|s
zA5P6)grIs)05Pk^a`lCROO|^GsY3!<0#S~^8b!7&gZN_lY3AjCBrKYYeO1C>p)>kY
z0P1{p-Bi7;7YOSOMk=A`Zkd(tyoQ<`Xjki-M-BMw%5yCI{@DQelR4Niy5=T+T~T$A
zrLzjk8{ir`PlF+}W=}*-IEZ!Z#F7UxCO#&`XpvOcM-qoPHJg*S6ujQZ(x}k;`^NTL
zrz=(G{v8>=C+Q4H3r{Bvjz75EpNU;%Qbfugdb!Q3_wV>6X}<5n=>_a|(3`Cbf<;3%
zirCkr-W-~c-w@?9WaI8QMf+R#Z8VrTP|y$YmN(Mdtm&k33BbFzDb^azi`b0r5H80wQ=0JkIn1%;N7?n`7c
zDx;85*BF23QEZDq_M-?n(vZs>q-R8uSBF944$wZe@=TzWjsp_O;ItF8XIUajLZO9g
zLFng7lhi|nA?YKm+Bn5bZU{9bVNaT+Bg6?A#LHqCYjrYg_&?4S6(PdOURfNPH7{#p
zw}1wdxIEeIspvC0KxgYE@>{GbiGT5to~tJNLVc67;-E`r^uOLz2BJ3SOU*;~-dulb^?X9X4dSf1&gKYGKuL4UZu0D+-UQq%OtL$G0aF7{{dU{k
zdBUps#W!|zoUBHw9svRk&n*O^szC29?D5CItw_Mi04n}nvg_AdL!$E|Ih({K~vc{nOQi7Ja04k7v
zLl<&fFd8Y;n3KJhC!5}9krSZzoEq;toOw}cX=wwe
zLx&FO%;t^Ng`C1wl(s6QytC*ke#)lqOe2a@I=kfK*?v<=?&seYFQg_w;vOg;E%^Cs
zzTPj3I}Fh{gj;Xpa_S!@0Lgr6Yb!%bE!w%cun+va3=0-eHHDS&%5-j4b!%Xhk)VwS
z5T`yhd!?s#uS1t2Y<|xB#b2O0ba{_IXp0O)=Od=-J`M?6Cq`FBR904&lnzoq(2*4D
zp@UrtQD!MxH+puxF7=@C6mMcm(+#oT|N4=%eA<<2uGj+&E*D7Fa8Bu(AnIZO9h2>%
z0%u|)w#p`_(QS3Litb$mu?7k*ca80V155)gL0%*m253r)bZ8sITv`OA0zHB3VQ90=
z&9eo0AfY;>e24IS2ROjLtpN@vwo3xrsdat5?}X92l0d#P->I=?52(#tXDm@M6bN>@
zp4IXN7p7hfQx?x5pPK*eIojPHt%Qh^V+~3uDU;BdUh~IKGpPf!*LK@6L>H8;v|>wu
z|3_x0tEQ46Dl<@5$T@|ELz6`MB=$Vyx~#1xdK=YdvJg9FobGvnE)C*1~F1hlRYp
zWYDW~t4KGa_RU+leFA!lQ*Mbc>X+q5D)1CXw!UI;;>uoDrCpk
z?CyCXbb_}_o>aM$mPaz>*DQc?35XKA_LX}Mc9xL9SXx#+NdAVhxnBy*F0b0fULUA_
zZoxJ8&{Qd#xj^}R`{*@jTtIFIlwk$(AC{KFk`@uXt69AMA({@P=#hki1QeNTi}n%=
zi6s=RrL?R#=<@!LtKzp`=TaO0x|L$sf0(+w07VM^_M!MhrjG(Ij7R?SfAJp~RCz=l
zX27OqrLjS{HH3TyD%5-D5+4TrBF9;B)*+(@QG>nBYNS;$2a$zCz>0E>|GEu3Ld`~3*&B!*w6c&xs3}Y2j
zU8Nsc7RnZdI-f~BG}N%7GAc-^4?f=jcQ08YJOy+`rD2D%vv#=R@k6XSRDsk!ORI+L
zhGQ?7S&80iiAn@fh*LtT>XZ&8jD*ycLLk{tX<-*S4Q4jiJb_WBzD@StO)CrlwM#0F
z+?8bP53ueV9D4a|m)T}&*k3h2rj2%5&jKrFg!~5Nl1YfWWoP&-7o+2+`@kK!ky5Wdt=cZk7xCeY+hN!Xpq)Pp7Y9++V5p<&-g@cbbOkI>7#CvKW
zfKj+urxgsDwY1RTX2bVTVXJ*Y`?Yn;4V-rG-VJ+!+EdDVBjR2cqN1$!HIqVlbAKng58f>Pa;2_hDo8a;
zmRZGZzrGQ9${H0f6pHkQit;|JydPnWl+f}hH8nKk_wI{wooJ|zgQ!%2T4s{I?~!n!
zWujqhEwVm&1u6jkYGD$b^PAuFt#Cp{B&EBYOlL_KF9`+sN!Q_00r{b!uG;zUQi8
z?~Yf5DsFCa*E3gxnTk1kUW3cq9TmWjpeK0u$+2Exf=6(2C{%Qj+UhVq+PUt1^HSME
z+oRVFqHA1&PA@uZz4HK_3^uwm)6z*vP
z{04nG?;dkbAQ}^g9jHp<6TQBn(0x?aLc^wRIQQUdKzK-@YOj_1UuIR4?U4bDtD&c-
zcWQz>JYkQk#sF{I3-sVex&xKpf9%LA=qPv9Pvv1VOqIWXx-?6O%aE`nq$z7OR>;%0
zs8#_^QRI@s*EMZ%N7+Iok*4kAg+5f=$YOz`wmbQ_*gIejqcBDIv-X%B|07?bY(h}LzO#P$NzM#-Y7)ly_1%KiFy
zZjI?%I(=w(R)2qRY==Df4tXDIPFSG}X5D~zt|;^wgXcghy{)dPh!5R^I!OqR@gjqy
zZIIw6nPLa0)p(2Wl!@(*P$3mAtN8j%J8qo*UZt6n^q&;oWa#qmF=cmFpxTxkYu`!<
zgq8&LVd7(#lv$=99+GKldy|(E0i8&@=4ngKt##{lJ*O!l2gfZJ;sPwf$>`2I7W9k_
z)>xHE#z%0Wv`$rb6^eC(wP>4rA&aNULZYR1Bg%Tc(vyY0y{M#ST{hJIK^I57(+Edq
z#_1~GUSbEdKL*`|$YB_A@Fsyp1}y}pmm8$Uit;!hFLulUU(ZdRz)81;;DBrfOPBxW
zG1$$x-%VH|uBg!$h&H__#VY=L6AQ)IkJ^lp&}nl&-*AXVp{f8ToAIRP;!8zfe}b*g
zj|(BHQA+|hemcLEtrsIeHHmvblMIE?O-;$|%VOWL`eyPeNDem{f61;tLQx#Y6RluZ
z8FLs`np5R1V6r*I%Lbe4aT>YKJkV5+KK-xib1{$ZQ6HW(4%zeHhLvE48W|nOp%zbC
zJr7@HzQ(uQMUqfPZ(`QjLsOQUYaT*tq`vC;gguozKO1)~fW_qL%Yi8K@%MnPcYvuc
zT)v$3IuDk!b+K(sY{JD6m3A06=R{;ITl#Ink@TU~%X><}c(~pH(yI{KNWgi@comie
zHZj9j6S#@-LMn%6am?h{DCHJ|Zsaa&BxrwoqYKbf5oh|RHh12AmtY_od$fH;dH1Zj
zi&o5uv+PEzhHW~n7;B*hFy4F5r{b2u<=XoO?6sq8I<9;=UTTlhs6D}sP-yb+(yY`-
zhIai~ed;~X&Ys!|`Bj8>0ik9AA&>_xl8=*fkA>^5=;0Jfcc}ZD49f8%ldoS!6C^cP$3$XL7>RoWqz*U%0T{bq@b#^_Ykg9E0NCqH-4_?p-A%3{0gi4#_W67Y~r@8vbQdw#0V$EIpPJ_$D#Ln(QxFu2|)k
zI0-c3$@%B;>4H`?Bvr{2C#XzphXHgXCh)5e+|*js8zOc)zMLs-gwCC+`PL6mB{W$k~gh#@p1mt-MM@1o!Tf
zIEs?NJxG>h>cIHh^Q|#i2rBw-sHpB5euKD;=g=2^pr$~2g(J|jH@1cuStFqYqAqBv
zt0PcbzyY==I-i-tYYp$1uJPVCt_N)jqohpIm#uyH<=%{$t$Xew`;xS602_#P!$_h5WcdjjrA70fFK7y|!)nuBMHarW)1^j=FIPPC
zdcJ(?*%c?JEaAQgYZit_NB&6e$FpJc-6elTgU#OoNPkbzr5M|0mvf1|(8>K5WSgac
zPf!4_Ox5?4s449MfY%+LOaMu*K(mZlfOcni<=ACnss@dCLoVO5QuCm43Q!vYW;vC<
z8`!!MK=qv;+U`7FdtdXQ0eF`ye7d5^(^u^vkJAd?B1>fQn)l!I33z0i7r#E
zOQyf#NN=rkLNX=j<2zOa|Cg0*@2Q!ySJvIdTlsHw_DutQCedZFyrYTw_>$MP
zsC(?c9eyADxa&8`9ZD^nS^@?M>2I+WGL}$EXLM0-lM3t-VO=qfhaj{SInK@WUF6ja
zMXphx4cc<%PW&pnul4wJ7P@c2q&SfCY^1(h45!?N2+Op+1)B{7+d=Zgo@F+&{PqUfCKwjsO
zm5^MH@cc;jW~id*o$#NZH*y~@TLNw>1Td{6
z(M5u!`tm1a;)Yhb$RtOSa#QBv3YK(DDD|zNg>wt8`)b>fm+F{2e4nH>G?4{hX4EmZ
zf3dvo=W*WrM-cGgCc#KlioSkbv7Is`5xW8L{h0?XejWNgkO!92=!QJO%
z^&nD0$#3Gn<*qGeJjhsyY#X-|uxZtvQ~{9cJvt5~cR_yDG4?{7#0)sk0_e~+e+c!H
znI&e2gja#qZkAAC_DV9@5bdZa7wsrgfNfx1YqKjeZT6DtR12@0=NdHBy(b21B!|~N
zML+lIP~>#*mhAYUiS?C_DjC+aWcfJ(F5(c>`|1NYiRT&FZ9+z}4Ni3%^pJC{oA1VW
zBXRq8-n(~?g^tpiDz=nJ#^<6l7x%!9*(xNe6|&v928n?ADAw1l+SIlL@x?aV12kz|
zU;k_>-xV*>2W%jo)Cd+mN6rdqilid4M7mb{fb=@}B;&k8I=W!^3kwY@qM3d02xwtP
z?f{);z}Zrg`mkro4$zn@M+Zj-yW)+DaBPSe4S@EFh``-4ULS!prtA2i^@?2-!)Pbi
z1eumwsQy~Gy^l3HJwVp+_FF^9;9^O-VlOjue-luKQW;fD-RlbGpSIQH!Bd`Zd@fPkM@W6#bDqSniGS_l`+W3NmK~rNwU%}|
zKzeHN6AkgB+Tq{ni>Kb9wA}}#UGMuJH;^M&rZ%#hgBEbyF|(HHTpOh-4STDe-77jP
zq@oyy4;1dL^>u!VRw!h}tm1M`?2>lZ$A3}uVVmP(`TmZp;?WV9JD4_POtd%P;=?JcV45q-;Vw5A_=v{QeMeHX&~dkQAW$
zNzEQSd`~
zO|(2osQG!mz+$Mz{!}i^jNbI`w<(Aco97bjR|{N2IAu-Yex2)Zu9VW`Al0<(;qp=y
zH(Z*dI-x_gkGWU28DBh|f~lnQ2{m}iJuz+HfN?qTb4o_9r%M?ADr^M_RYvcg8yW=K
z=6i2%Ig{DFH*BjBL}w@o%>A9{hOI{PH}CPdm*%RXE;M;!svhDX`8(8F4O3ea(Au4K2)jIXAi|MGZStGij)hjexu``3I)N#X1E
zKZejJDl%X3)A=puU=2crAkCBX6hA)#_cXJ^Kv~d2<4{CY^v+Hn^q7~BcGlNY7Ju+H
zX?@h@L_G?i;NE{ukj#sITjwjTp!Uoa75PQT+yOYTDYC!ht#}F1gY0+gMgX!}A7V#L
zwRm4v<-)kg0kaDOS(ag0+FOt4K4Rlg^w}?UK~9f8lG!&
zNFdZ0S%hx1KUHEAt&gO0(^!D0W{}_CHs-KzuFz@pw`m^{k&d%jT38)lZ+(|EH=wZ!)y;{^tFTfu2x-_evreL9mGBbpPsKIF<+_u
z797~<%D~Z929(8)5-f{IFLjZK>a&e>1v;0|`lY`YT?e|z)4##t>WWWpSr>1}c9hGA
z&<0ABYM+#YqKQnTQaw@&7Qb`~N{$W2&+MAr!JH@^kj1YGJu^MHgQ=0jM*YJ|a4oh#
zG1MCiqHZ(cO7b+ICGJm5nr!HY;k>RJHnRc>;zcvpm3j?i><8hHfnG(PHQEZcDfeRp
z`ur;z6Pyc{#77X!=VKyVY&mITB<0#zVco|a@u*b?r&Jq>a`li_Bt;=+_
zpeqv4a0k5x{~6ZEh&L0O(FeE;AqAZ-xFGf<5B^T3?-g~us14hUwH{GR41FXY*s@oh
zgvvbvq#D?&n!B>cma{R>m2Koi!PoH&+rq|2kQfU=0oVf?XzUZUBG*qx#Uy$M*%AYF
zE{(%Rzs}OH87#~^n;3sgq7)}hR9xJ)plC{`_5Svx+wZ#!tIBJ5R~h?ZQN)zIS9cVg
zwApoSX2J6??V%{I-DJ~hBCMPy8GP^l{T5_uUAtGXP6*^Q9MjctFgy`#2~sCHxHib;
zWQfEwUx#Ubj>McFF1q>s!aG=2NChQdyQm1Ii4e#+tEelq_3Ur{h|2f|SqN!FC`ns-
zt8(NvO}|O`5KSFvGl0&L`NecGHLc7azCRSs+q|snyw1K)qv@&X#i)_W)bwaab_DNXAeD~>M9f{^?W#bIvaV0dJ&4H+
zz?5j{88jYE(}UyP&c0+IB6l#qbisfj0!X}b5C@aSkXoRG=sO(#d;{$=>fw_9h9$FU
z$Yr3JbDpMV;6XI2v>88ya0c!Tpjj2|(0iI^6#x2Ie(|i=%83F2D4c9N=tVKzYEuR3
zMHG;sv`hISd#yCn^A5KIm
zbTX65J23~Qo3uznHbBZWXL(Ub@!{StQ=K
z8+wSlYz@F2zR{OKsb
z>FTo`bHpw7ce2?j6pd?FROX6aR@7a!UF9-=%{dUq)c!tXebSM$Ef1N37b(NV#31ww
z9}pLePEY+d_2h~|S_!v~=$WFUMs+b4=g4Rdvwk7XAChRC+j?cMh^*}QD5oh1b7Uu`
zR(s&VRGeq_)HQnShuLN#C^FJe10EXuoTwF6*rl%is3zcISjbl8i%r~U$aBBjg^$Dl
z$_H>-r2HjCO{jC~v#-;%y8)`Xx>uhsoh?k^1XBSciNFKaH=PaJR31+_W{F37h@b?R
zRNC-JfWK0DBDX$4U9vZ7LJ*L@wrtQ(qB2X=;-@n0w7Z-gzZyv=mUC#+%fzM+Q;)u7
zgRdOUMc90=M964$LwaQes&=8qk(ZXnJ3>pN`O6*O9Gu
zV*(dmmoGf?Ys>V&x5T_3=eS*
z%a_@^73H;Rb47=7-Q({U|H`w#kz+vTiI>`Ne;K)2>3qRC%boU~`k!P0#uVu7#o}XX
zuqyRoaFmH|TcN?vo;jGwPtEI03achWMFoZIPW1w0qad@Wo>E7Q2il_S;?ZSs<8n5(
zh6ZeOJpanlDSNcMgfrC(?QyXr^43>-2!amlM-OTVoj9O#+^KV=exv}kY-th#?yURj
zrv|$&jhRvhp9PI!RfkNq9dtT&oDxI_m3K<~{ysKI;`Xx+Pw4{SFJ@LpZqwCIkMy-S
zI8>uj0~?c1jz*{rlfVm3&@S|XgE=`c{6c!Q+NX-k8fe?WWpiFUE2Ga{YDjm@%A0h42N1$EcmdT^a1&F%9!Ioti8n
zJ`Z4f#qhi@H6_sJ^`E;ocNxRP+ModVimqlzyIN44LxLiu`+si^ZBwSX7y43MWcD6_
zEnMpM6{0Pqkvu3L-C38)%_-*PB>faJNI-*V$M7CZ7BW?kAg{M8D!?vCJ)EhpS>$=Z
zo6R-f3(2Y$0G_KnGjZv42Dphc@d}0j?2{RO4fUtqZ4g;|0
zl8upWs<7+8;uLT_?78}B`L?^U$1HY~jRjsD{q1}eVBXbSH*Brlul)L*&=2*@k--{?
z3nBZVGhVYdv~on%(P!1`nVCABzu&)j9{pwn=x>!63At2&6teXV5(Y7H1>H(7G=Y4%$~4y+jf}3>V}Kp
zBEEwZ>fhA{$82^%AeTtQrAwEt@EiX6%^dC;v!OLX#(zH1v42fD_i3TSC7s;|sP
zV~?}J9&Qn2G+|Ht3Wp}F8uUq?S_5RbrGN*9w%Go50(QqXH5sp=FV@ZRE6iblm7IM(
z6bTb5&F&Z3^{^K14
zO7Wzzey^)EKw@zCD}Z~p_N+^bs9=&hO!Hk^hH0&B@GK$Jy(A~p2u8yKcs%mfF728~
zhFs-usYRY}vZ|QaN(O;Du8jl@VGrVj0t1x6b1gtK7HkED3N(zc*GnZ8)+CzmdCTz7
zG)VoaYb#hicMW+N$ps$tc03B~5lailcl6;vKYhPRE5OFHeZa=Zo`s1VUN9S3S~
zlVSbBkd#xDXytc)YahXN(<1!&BL{VU=r!$|peaw9q4BpbLX(%qxZdXZBD-4C`%z!p
zz(?tE({5d^xIFoeIdhXCgV1h}>vdl@nD*Oyw)^R-7zTY|G~ZwET+GKq7$*)w`-!iA
zL*ff~&!$oMHPWF$DhPHHC=56V9yMIVzETC&F`X`8n?-_mi~v_gYy$M;*httzRx2z00KGJ&GXo*^??V076MW6hGsLStq-#w_vhhICKkWOJ1l
z#}YI~(w=fl=|VbYP4YOR)z5$#B~GB-6leZ0czLGd`}Jn6`00Ak)|z6pp6
zL*BT=Mr2q*4~`2jnhNe4OhT}CwVp@wZ&V`oiEdrCQ+V|R4EIYYF&c)(6pf*hrmrOL
zI*!bgw%sdg^n>!BxG=0Iwta&7?+|c#PLvNG+&bYnT}YSxXjYb?VO2OqXFS;^Dnb`u
zV8($ulxtFCDW5T%R^5Iw>lyHURURd9K>4@pw_;n#4bB#29q>RmP-02J^LHQ<3xMmY
z!u{3J5ZwJcp;j$|bKf9td&k7v_c+6B4<9u)UPdUgCrM}`ng)Ft>lmBKDn<1e!}O^;
zyC7hZ+DH0qE@qtA*Dd`t^JgYRA{syPPnFaL-raZSH=|HLHj+j+Omg;9^4ifVho#=+
zsp$nt5gn8!h6XdBGQ#jS011tTuLS^
ze8}qblRg%Et(=ciWWJdk*f!7Q}TR(5_fg8ge{M
zy^#DP=INWayfLcMiK1p)QCcuc77ejIpM**h0qkj+X7nz-i=Kq7%>KuJQJ|t<2@fPW
z&{?C{-b99nhmU+WB2lwPy*=RY1+4tW`>s8D~H_|!Sc;H-9dT5+0%h;cm&XbhE+1ifSoQ*5hp
zpf=m^L;*Z9E9QG|r%M^&e1i9m8#v3qs7TEVr_!1zH%u0X;aS{pq+ug8_wWQOa{Kkl
z;;MRz5-_E6O;1{@~WMI!GQ{$;Yn(%#o-!Xtj
z%SfG%u7mnr1~pg!*+{>kXhrJkrF9d@sMHl0aY;ZARb8DdspPRx-+%v-xPAyp1~eg;
z`O0W9HeLpcs-6Y3so#K;V*rxKoKT1vGs>eIyI>v{RKxtwsVv+HWOHHYISeh3!}t=*
zMvnq9te9_Pg{DaZL>>0fTfH*M{kDz|7pV_i(H3pY9`^1M*NeC|l6UR73dc!jP(1b=
zkIRlUL^vjBW=b~7Rv#qW3fkdXFlwL5T3XjeN$H}{bfDeZ`2+YbUYP($h6NNAp{5rF
zM405ILbXlfD7xG=N1&n2s$Ox^35=J5&3H!dMe-O>8kkirWuChZ#AzuV$|}{~qN$TG
z-i|G>x@(fba(>i$cWX5b0i)3{w$*}8Kbrj8dA!xo>e=#&kCTFibjUAXzNmQ@03rKc
z%K$ga!0S$ZbZ
z5imx&JxJeJ4d99Ss1|2YTXLI9^w)Dn$-4U&gh*^dmzK-oxK@aO>6$U;ZL|s~cT-Xc
zin~Y-flj8Lh5V=eV2bJpHb?5~NgANlKnE}yHFDO_OSmcNeuY+3c?}vR0Ct(tMR%sa
z9u*$mEMmRn%LTVSQ)gTJ$4_Te5AX9)q9zU{NH~balZ2H$
zj^?At(+p&0`Fc>mu+w#+IPRI{bZA5b2iC#o&&+}ZDwV#nW2`8qSGx8qF1(N?!cD$#
zmK;VZPAEeoGJuGe`3mL`Ub+wi$ePDw6Lx@x@#}`x9w`}oFm2_(mHVIa!6>`sm4&WE
zC{4#Wc>%Z9_@9<|fwh5xyfz)o;+X&Ft7|f%28~hk_}9nm?5d8VD0UzV>XvR&^(Mu(
zsjzBxzCm^<0SZ`3V-@BY0e6_kWSe5T-}qEJ@<$%(V*Y^2=Z9X=1>w-$6{{Xr#U
z`%Y?~Ug7JElR%b1_Pr%IME=#ruxSt;BfBM$Dj`fwCxHHByne9U4tbTJe5)%31Zyi>
zDOA6G39%H6O^j96Wz6aqzhLjtg*b<>3{K=f<8aLev@ZWPYz?H?9h6yErTI*-F`&8vI5o}WFb
z_7L>2#^aaG=v~tYn7`*5=vJ-k+FW!;m6bmJ#cep>&l1JLs5Bc{xgN{D7)jBAO8+Ff+|5Q0=G-3JOv$
zFN?YT8iAx^Ty|Cp*UQjYBw=Kh9I4>shcfW+O}cn8gqC!UYcX9z(-IecB!CD7aycJR
zs;P9j>px0n@!6#9+6(@nZyd?Kr2f;*B+%=)-XSlyWZ+P$lLDRD3xtM(*R##M8Ocb(
zs?htC>t!^otDK?L(m=$)=PRbTpVRd>P@^yBs`^t$i;gf*+N|nAT#G?d!s3%cd%Vyp
z*GaTY~gkMRp1?n)&TbU}4(O`Ag
z#J}boI(QJBVz~rwxecA9P1O(XAC$Y3uGTx4
zNVgO~U8l=q49yg%ll!riv&d2_97gYJeQawTknfBxXnv9Ey~jh~`rh&?LsG}nLUQ|g
zPxzYt?d{#1){6ddHnu5kFr#3oZZ!sI(%6yo1#=?3OvtW_QMPpJ9TO#9w;5D=88J(j
zzRV|&CPrB=CJfP~FJu>i;~=6^2=6a7uvtUwBeHWSt#03`4McM3dXGsoLCQ|e7BcxX
zfyH1C^unqy7VAn1RmLnKH4sedm&LnmgvN!PzCPlbP1j(tr#e6pw^uoDRZTd10D$%m
zZW)mDQGJg3LHLO7wXnXP2ypK|f5{C@EzxBcH0v8SI*k9K;X*ZUr@`BdqN$v-0GH}O
z&41mRrn}+!2Yj9W{}KU5V`@(kE@F-1Fp}wBbknAXZ@gYplBbzg_UY3;odBqRIpGNI
z=konEY1LB%z91RH?RIRY%f`rl5nV^udy%2va0&yQ<6<1%PG;GT{BWC_w%Fk8`m&C;
z&}aixwTmH(Ig5Ty7oR6bms8e^48|b>C`!VA3~yXg{upSkYtg#PCUz7$Or7g-{F#g
zedcm|sF{vg?{@p!(=fRLE|0S6+23^!RPf(d;u%$vt%ioobsXqE0@7y4&89J@ZxSua
zmgM`db-YcO_8_bQ@S_yH8C!G!!9lkl#JD6Bvq}Ql(=z5!E|VnP#FX&t_CfK0;{T{bX@>Ie=YHeoA{`^T4ym}bd2%R6PBk0QR*MWVgWnNcEzlIU8878z-+VORO$D9M>Jd+S
z^a)uq8I7Wa~hrm~Nr4tAF
zMYXT|0OFK7wdz@@cV=yJIUBwsKv@XJ$tSlLI9$Rdfz(mOkmhp66uD=9d8I_c8%$YGQ&m;6$_>
zfW#=KT#Pz5!KA}H_p3&Xn{VBDj0VKG9Ii%A=)A{4hGHFGV
zo~)Vg+%S#+e7rf@?4O{#{vSa}KF8Qw1%7_Oa3v({5#1tLmPB0Y72Ci_&YqUU4N5d(
zNi*-{yQTIURWW*cx(ZOT+ZE(~<@{TY_<~hPoo;HoV6w5=?9$Dff}{zNa_T=8XOD-X
z>#4<4eKB;#f;t*PL^wqb!LK@
z=cWz_COy_Mn#rdKpQUBYduo+P0Bm0HR99zh-8-@ik&oxB9$8{>i9~6-cM1nzN0}z`
zYDR10PO5d0Tf?D))CM$+q(A#luIlJ-NAA=g&!;L&Hw#Pp024>aZxD|-x`lL1#sugU
zKLp@OGc~$jauv;?Q_QnIG*_;K#+T>Ty@Vj@iEdh@DBN~Hu&xITk2((od&tO3s02?1
zl}VEvA9Q~PMNGu8=aS?`=C!Y;TYIR(x+P`DJ0N!Uq%lrvYF{6i?FaxN`N@Kja>RPZ
za1EXmVE)wv`OCwaxKsPNRbe)Y6`9`}PFLE(am20ynYiRpiCh34D1VhhMOi+RYOpqQ
z39FCL;4!v;5g4?UZ6hqVJZ#F-JNy^7-Zt}To|rDPdemycG;qNM)5&xjG@*P-f_9Ar
zdC?vVrm1+7M5}UB1-HFnp3WGx5Rfp(UE{rV0gH4BbmleMi!y5D0p%pTZ7es0NmE
z0V4bPhcG$L*4@csg`9wQ4aWUkL71ha3*oEzGYePJEE;-kCvVCTry|c2%}|SXKEUg<
z**QKf=!mF^w&2?StFl0nenRn;SI3z26tfVo^(;;
zUZ<>GWGR9oDhQp4`lu`B$Apdyw&iAjn1Y3hZqeaB_oNZFWcp~*L%C+nS^%RFUG$({
zK!Y0?Vt%|txE{(S9VN@QbQw@?n`;H*;48l>cfgF9c
z$6iR;kogrmUV>>XE}H*ieP|*o2h<-vGPYXk(CY|(!YRg>Iswb>-+Qzb6M}TdCTY51
zUZA=U=%$RA2icgeVrz=c;nfH3rep_@KXco0M|;Buscaz+n_r7mf$=_MGV&cRf@e?H
zF2uDUw6?0(mhNDsTuMDG&>gfDKPEdbCD)n
z=ssrj6yowvV%_Nq66&K^{`PjF6HcsP*j*mfj;1}s$>0&(J1Rzt;?bHNp$ybYRljJMy4w{!-sTy%O3!f6tqZ!F0
z4>9^mZ$Mj7+&4mfAa|cm6*~c7#!VfQ7A-UfPR@7G)QCQ%
znxCi3{CH)4cshiU5;pD%I^oj+DoPhxGbTt?C5Z`mB|ftpFo@mHu*colb9nP|))LVm
zqXs3m*vunz;r1DyG;IVfE6hkw+Za@(Mib3TQZpD+z{`L8o{Q2&HPxBukMl#Z`9H?4
zJ1*z7|A*rkhq8sTM)e&-{kU_?1>3p5
zG42xCycno9-(y`^vx%#sag~(hN7*WbjZtD06*#3;8gokO6^h}6ia!u0x5i<$7Mf#5
z@fhl)`K9LQglQ7tD4&d$+JY?d1g_Dr0-7HIkb-o!H<0$JRQSyHp8QJnSu=%x2e);Z
z1OV7HUiNuFgsEM>jAR+}?c3LpDgg$(yHiLq#yE1}dR#3jLs
z*NbXZpV48`^oJ)}k>r{7p*bo33>rK-T5ITTrCBsU0K6vDG3!qMvEe+j!u;d|G6?*=I6By@)FU-|}7>FrWu)=73fvfQ67EH$*xbMph%Eo
zlCBI*DI+YH4rAEW4Mh}{++OVQYBv$j$t0L&@Ka*-W8`0$eH%jNHLAV|v-e!QYWZ@q
zXRl%J8%{^#D?+%3-rXiu_?2okq%K_=Ed+?JK??Tnsb)aQGk2t)nfUQLxF
zd4CD&Gcx=@&o_7f2v%$kT$pzgT*;I}OH*WqgK5m75;MxoG(HF$(J#~$Y}5w~BB6sf
zG#Uvn6IoJ&5Kj_InsJ5-OXF}W;f$nA1@0IBW75C~s?BmUcSV2On81ShJXtK$q9ig=Xkp|5(soV)l0kXZCTXg
z)1%2nyMI5O2PsX=j`e&u@(U*@(}z8%jUq({^R5N*r1W~l@pCi%wWR>*LXzaSGaHL>(em-V;1pJ76*`c2t^x8+<_zh7h}jl7G!!c|0TT2R=wNs2
zH}sEmR?&qf!faPOWuh}25=nP4wypS7XLk3(>KQm;=M{;uNN`|Og`aW)iB$V?l#KM`
zFKp&Czv3woaP#+<<^U^b2Ec>Mb7sz@!jW~9lBL@aOP7TM8L;CtoP9|9l`N=G
zOao;cEpz}o8L`yKW{=_n#t>qa5S`=lp7JIG`G1J1z1by_xeP04KPgn(lu(2L=(h!`
zgVbH|OPo7)kbW5BqDFWTmY!sU$o8fiNhRRwN#|VU#@H-u^iky(K24fh|NKMP@0c#0hEuZQ@x>N0i%yA
zsZyaD>XPZT;?b_>dAXbId*Q~Hm`Cf4UW62$X#Wlbaan7EZy`40=$(xKu1?d^0^u`N#}G0d(Zrjk
z1Y@f@BJO_8$Z3TPVO=^pHTQPio4G;YO<7<%EWaNc$-I{;>&L1?S8Lt2edIeZm}aPr?d)|A
z95^87FQyLZ;1dm}PtV3>=j4bwyPu}+gPmi<#>%P?Isb5Ht?WcOh-bRc93mopy@NV1+ABf6DQV}W#FliVH3OX-rrVLiL0g~
zj5ZdWti*d%K=+&5ttE|OTN{vrAN|B@7=N%(+%c|dM+Uo=0ziZ-AS?e16hT^X~*v1o-PCxOuq#~cg&R1(0;-QBcf_r*CUUfOTDC%bUd
z;ggv8e3YiVxsM+%iP1Gra)ZMjBRY~RFB=NVYI`C)2cZNAhiW;weBBz}rHCdC?XX!h8Jx!y3Dsz4RC4902{pmdZ64pfBZZZ!@Q607LN_V@_C
zya)mev5HQ8v@joWGr24ectOREn*ViiV<}LhaP(}+-dV2)B$p=2uD+?`+IO7h*2)iv
ze?Tn}4rm!Oh~DxJ{ge;T;+V=m=Iof!)$)uFI-n@r7$XSN*XpDt7504V=um*PIqq2Q
zEO89-VP{o5Il4#APehp(d#VQ{=|uY>@oD1EY)VZ^e*67*W1vywS&_Ui7`wkApu|Ip~^=_!HR1;RPq
zuw55su{Ymaf+mdy(`%vdc$`X(M?r)4T&OYLpSoG~z9fquSjI9pH@C)VR`VC&!x6-H
z;6T1sH%yD?U@CE{avK@)48kcUi$n&GV2kYAw{J^aPSv2J{{HKOC~@%sKB0O%a(Z_+qP>MUc)!V)#q|CcM9&+I4y2(njSz`-Wabm
zVa;2&=;-K{ve*zHOOfDRwp9XD>Ad|6H?o=Jlgn9)1y5Yu(?
z{6tmJIKM7lzhT1$T!w33B^361>gnlO3QjnD3Bk$`y)4uVVOfq~9CWbIKLUhkemr6P
z*rnJeFc?8cSdaP7Vi5-i$8@f33b05PMg4GD>b?wX}qTL<)zPCk0bV
ztmang%c!Z{b^TF0i9R)OyfdFaf1akT3Thr6X-$9+C(2Dql2TH_ksGK=f($SSluQdm
z?n1q^$dT~7@%FiEQME@|KKG9+T!LpjZTA
zT7s%H9vfu0wDcPjp^Bm`EFY9xhajdbPO)KltQke~X(ND(I@azZgqrZ7!uLz6X?o&CKnpMZc8v{>nwcOFAK5e=mKfJ)e*
z-FQi5G_~e`^rYev#v9}>ju1yJ9{Tw4CaHmmlP0lHelXD)5oV5(%4uir)uM7l%&m@H!m2>Ap?_Du!W;*1BO0*I*j>S
zx}+$`^51hb@(*#zN{-b7
z9+3AzpY#Yd9HzmlQzWisR8>_K2TDz%ImC%B9N5li*A2S)w}Y;D@Zf>v+)r9}&z?I+
z7GjKi&%izmDJ@l#Vq)`14Glok+w|No94-7p0mJJ!8ILzgH&X{){YPy3*8a;@e!0gE
z85)KHal4Ng-LUpu^#VXx4`HS`h{?r|K-BC1h>wFuaViIgo#A!s?b}ZR8jOOx-w+>Q
zpA04xnxb362ek3X8_x*`l%NZsi;I^p7a=Fap@~U#HBd>lHO1wSqw7_eH8^Gqu}xiZ
zv#fgL^G~0`Uq7l;22m6YEnExonTYFP62#XLKt*^rY>*EN3)^?#z@1B%E+IJA*VZo6
z)zwAS5C)zbpYs6_EILXlH!Ew!fkyfL`wQ_A;!b>#=9X|)`|#nzK3Q3>GLzdeF+{JB
z10W1b2Bc*2#W}p5A)O0O{OCB+TKYt4ptqfg-p6`)&|%f8RlMTj+6b-RySs1EUwZLA
zEGSSY4pZENPjxFX(at~b?(N%DL>gg3iCEU1r}qoAy=Lc4#X(zPH@KbD%V_w?DD_4}Yw0sX8I#k!z7Cmd#K02;6Tr&}c
zit!`-U;XRxmDfuVYG_+Jl;DCeKo*xjK5|ChcUvI-xo+7qs+ewI%OGVbqr9b03XMrP
znz?9yUOi-MBf75Dh-{Rog)>+yqW>6H2Z^C&MLhC
z1?~OQJk}cGj5{Rw)yRJ`s6akq1$y1R%X2f=`io#R>ZcUsz4aL!Nrf;S5_49bY8QiT
zOT*ckK|W=CTd<9fPY!wVdq)S(2)zT8+UCpQJ1{y#;IADl_cEdGrOY4|kKL5lBVGpd
zA>Ogu7^84@A3S()tE+o-=GI6IWW8cwkITz7v?pn|p56l?lhXT;suGoy;@tN%S|@$&
z)~$()-15;st72VHrY?@sRK+XH`%QVg8VeC!#7Au6IdkXQylvlw&t*2)77VAv)G1RC
zB;t#bPJ9{f!t;}Jp?$x7ht!xM7e9jpj~Bf-5T@q}rf8t3Kw5L)H@5B7mb-rW?K##P
z`ow?!**{-TiX)L8OX{H*{(0p;kN^3d@3}F|>CYc3G;xVo?_DHhvPo$XCnuFWq7KV{
z{_yA5_uGGYefj~962zmZ)|c?ETlZGM1ql7IE%$uI=KlPnpI;5ew$(o$oU<0iLNOx2
zbS4IMm0c+O`GcQdEmkKEF8=wCf4)9DkBCnNFGx^+XF2S2W$1g?$p6x2u|2ytS#b-RP
zdv3wVa?W2LWShfzb*dula>>YMpRWJyM*sPH-#Px4=ly&LGv1y}o;K~B&9(dYmr~Qu
ze)lKx@$*-2=Pfku5)dRA_L+BAAK%a*n
zeCSSX7Rt58-hXxP^di1vhu!}A@_+w(qUV0sp9IU_5B`sD%9f3>w)~Huk3O{h)UTWU
zfBojk@fZHbX8zA_PMq~W*1>;%YwF5hThIUe3LcoVcMM+szkXr#F`n!HYvcX;>k3Q%
z*NXY|x9)FW{{OmQ|6N6&c+lND!u4f?(@;&AfXKQ(cJ{YER>5gy?FV#Bwe-8t{j9yH-QBk&Ou}Vzo$6Ow`AqY
zm4*mE*nd6eY?;WdmqU
zT0TCybVyOrDsl=32NMbeme)|Mj{&_&;1yRkrN4Z+3(53#R+iN8@GupSSpSaz=wlB^
zquf1(ZL~%dOBuI0D)2>Z{3DFmka^hDvKs8pTN2%SBtB7CwS*YwJzFPT4saS
zP`FTm+t#)L1uhsbdL?xH7C$nc6E%#VTz3YZULFfM}hBt1ESu{x<2uC=T
zVyZFzKp{S#lpOc<1n#L^fP^-3MC=d$HKHiH4&r-<=@Zx{eNJ{ZXfBVB7{+@VT)aY4
zLPA0VRj2E<6U@#1{-PV*YP?jZe$UU!17s3({`~pD#X!0-sS;I8o&s?c%7D(+9I%dn
zU{DT}Wv(qB-YjrVub3Ef|M&m8fiL+of5}15X37Uhm~KuBg>_E;k3)BY0MnFl=`>D5
zzgq}>Ym^|PghyOlyuzExTomIXA{tUy`a9~bhub?;snbhVE{LKZo)t43)kQ-28s2Cun!yV%`Xp46ONSuC!BQqfjB
zz~4U_b+M@ZTDByoZ(6^zEj73KT6z$|5fu}Yuci#RIz9s42nU1Ln04&$eZj_7$*uMM
z=dqPPALgn>udS@1p^?C7YilD=z&F7T)f`Z$PjSQ2L606ag8@^4a-tj6a3g1s|INF1
zIjNAqbPOuiDN3HI|>N#*K{wh0et!j-_v4K-6>u5F`uOgFMST54ekf#i#9g0g}3~
z?IgDlzyTC^C*Vwr0_zcTw0SKRi7^RgRo*Ne;3mb8X9uD}ZVb;y{Okr}%L|p1stX{V
zE(izeMVfw~6w*MKe3eyOTbtZ+EtLs>qJZAZ$?0fT(3QSxKXLdO>V5i%#yCC;;INYG
zHN~abuU@r)>-qo@)nsqG2e?637VxiT#@zCpGi%o8lNZ=NwY{=hz;<%56QDvaM(E~r
zV}WhQQvBr_aQx>QXgE5i01nrC5)6)7zQPUR04C%Mh5eed7cGi_RfE-0i!u2|LOy~n
z{njqX9PBLg5Av9~tc2Cd&bsI6`2Z$h8QNV{;_z^CV|xl8^T%>AgmNUg1@i^|Tm)22
z6ciL_bx%STc?3r?3*cm*_*Bd5=sH2`5=%}7Pkdq?a)d1uGP(Ql9P!qp
zi^>wIezs>-mnm2DtwD@|@<;b7i<&N!E}a%*i}bd|zozD?@i0mvdWht7qkTVj_UyLf
z?Fhb2aQ^3S9xN~WSXkJNtKWiwv+~W7KtXZ)C7CYN%i~k#@-9oH1PWe)P8jfs5x4z%
zH^IEdZ1nQ=E>y^v8W)RSWW?E$F!iK|ww
z-1aDw*iBlLz9~6Sg#?F&@}$P(zj^Z!3wQ34CGF3`%x(eXz%Atw6f6#AVeZhLODy1h
z1@MB0qkaZ79*PU>E#~kP{exB7Y9S#NS{ZUyT9|WtQg(mkB%dMAFH@
z&Z5)CCzfZrWrDg7MdObSdw-v}d|)?Vy%tEJ^V&DJ@t`xkc`Rk2=wvtpG01z9dTPw)PuibU|7&+q%79ILUe-HtF)?7%YdqQ)&7J!Q
z-Q|wP`~W+9^M_Ztgdc!hecwEwm1-*oO-7=quPq?L@^q)fGa)He$p5q%(lL3+J*F(p
zAvU`ktGT8p9=SdOel^OQSDDmK1DF;N^H|9(?pUNZ*~#8MNp6U31p-DmVQX+)zHUF*
zW=t>>0Po2OcGodqXdh_sOp^WLj?YA2udnEW+M)<`1-|UVmPVXV0$M}072=nO(LQ8}
z4$2b5>DaWED+9ATepL{4molQda+>`&G2dzbT#=Kai+(A6PwwOu7FLI&B38``Y+7zk
zPEO!}8Qwh@hm`CYy=j|ZGlO1nVKfe1j3)3AT>_ro2q7HkkDzrc7EBkGcN0QhywR(o
zA|)CI>CVyyzzqPv$P%o+jzbCzLQ7CJ!ySn;5p!>Vl2@$@rnp3+@M(h3s@LigKQNrtr%!#3j%RE
z-(GFv^15LmS;X!u3D`96R06KKkkWhrJb+fK<^^q#!I;U9m=*|#&m@$Idl(!SClFy%
zvUBp8)2A)9)@i52Wz&StsNaWb)-PZFSz4_pQ_#Fx3uUlih{?d>s$qJ-N6{v*i=
zF#}Yx17u4KGnbZjC(dDwoQvJ&@7`-=0s{llRoQ(PU}N)ti%yga0VghL@%;G)KYAaI
zxqi{*(CMV+95R0921mMyxZLI{Ur+p+lY!Z4r)I#zI!%PC;fhOghKxXii=*UC@e
zL)9vA@hS!F3$s@FicMkL{yBLXu)>cSE>5`48j<(FznC`{|K#&F>!9FmCEqS)1@;`OI$xP%ljBdCjhy{HJ3RIIeKlBwyF1nD8
zry>Nj`W7Uk;5G~;q&{3L`DtJ|!mAPToC-_t=usPxAeh5fXdLxH7*TYis;a6~Jmn^g
zkJ}$%1C-#PzJ73a26hJDIS!GvcRLCKrRJTf$;!&&e#+4%4}}7W4$x;D4wV4JPi^aW
zrISNw|7k2f5k&4SUp{}H!_6H96^*dNmirxh4;)ZMG@|0@?AhOt&ieZMsfQ(ou-M5L
zI}w>w2`!@X&Yh~nJ>?Y
zM^{#Evvm6oi$B|MzRaAYJQQfuGNE%nfO-5METZtqrF!3^bi)dccNwq+qM~c?-uDAO
zrS#t+mo;w4Ts{q@
z>gsBYw!e!-Jjx6K$H6R`P0nf)xgxN)-Lr~TQ31z3Me!PXXW~K7?fFT#XdodgSXw@X
zu0c#(ydIGbx2`Hl5?TdJT946c_t6A#g=Z*+XkQYhy8}<2JXs4=(aK69pZICk&HbAQYy7KJhuDi38@*~_zL&&K9n
ziG@s`97?-H=Z#2}N{4r*nfAGEE5~Jxv;U??C=WtkS&2O$8&q>Gg(24vCKyS*xMf{v
zmW>gu0EfR#c!jElNHHM%MafnMxHd)6k{&Yw&z3(uPO~4Wex^?@BO|jRT+fKzfEBh@U($7x#}qbo`di3RzIVB(U=ProX=sj`kU4N1tM3NX7+@Ds
zjNttczyyP_A}U1XhtkuDn8W&gA8YV*6)0MoO<&BsiMBr+w}4=a^}`z`_V7;Y6%V-4
z%7k1YLJ%_CJE%t}^Gy2tNLOS4i^ex2ZnabX$p!fC}SbPuIXB7t6(!=)VnhlXEQQl%=VwTg!
zrK_7|6Ec@HiN`-@OI5J8S3i*|TPm{Dx2c=`QSn
z*9kCBDT2X7J5vb!F-|{kq0Go{Kg2N}!x1`vC1Cs>@2p}G?S@dQ^uy&dT|5LeF19VX
z0;=X|X<*<8{W>XO#*ZIw;=N-1`gg(oKI$T&hq3q8Y9>ThrT=jTk&~B)XD_;%txA-S
zdUa3$YGgPk;}RHr6Qrn&1t4P8a9PN_IuccYJnkh;L?$^9y5m%XBr`Yyyf+_pFHS+J
zRWtdN1a2b+H_Aa9(tFz*F@Y#&U~$yr;ans|TA?Jl;td_Hp=3EoB~XY3xd`)zr`j&4Vj%nwrO=*S&LIV>H)wQ&B!8Kt2JQ}lKm(dvm0|gu0yvS
zg$gaT5yR^Ya0;K%6FR7etfc0BLQ(Vdc0BwYI01*y@#?=uy&r~Gx
z^*rp9572ydL37uFz$IKlh$cXCPqRLUu8xiZXm=md>xAYm-h}ZN
zO^?Er==ZG|y(j`5F|rCFsB!$vP)w9@C4kaaH{Yty_Nw!SEL&oc@p81#A|@qkj|fAF
z(qcO96Ye(2l8RI9%tPE=+Yxa|fkO?-j1JZbaC-zhUG0oDYz
zd+Y99yXrWD%w*7!lMqfE8gCNBPUH4HJQ}`%NHSiWog&2jUL7Yo~dS
zvt;z7Ck?M4GA#nnBDJGsq=#B}_;Ly$RD7yEf;yCjz;jMLG?EDTz6makpU^wEp)nO;2y{aCA2eSmkiq_MNY7H~nxkjNu9=UG`t7v{^1;4A@r@
zxbs_2XMP*8-_SUm?3|e5*c}ezlQxdyQPfDW4lCYn!~kd*=iFyWB49Z6f{XJT6Ue15
zeVH>5W9XVt#3l~#%V1?=nnxhi{<7F#Drl{w8sqBvoyu^1bEBCz5%5@;C|iUa#3fm&
z{a{mNuU(-*9l_rS^oh-lM{sVJ`{|+>e#^%pHzw>HBNRM&{6Dbc6~qS3s4!#|XBQVg
zkx`WwWYUQL)FsMSvGj3q`L$41tNS?K;Z=O4pp7?(JGj1t)32~(!KyD-8a{!j(dIF6
zC||%A@A&IYV5!v8Hlh$P+}Mpl;gTqs%#;T9T@FX>{v*Y@*va=K1HkbqU1^6S5A!EYs
zdn=4lQ2?>L6WF`yszv>qxg-rH1jzYH?y)z8&kPw@c_ElUyb~6sWN4f}KhB`40hYip
z^o(Np!)ZOPl9)TREVUOCsXjvfAplvD){rnHa5QhVvd6!Puz!fnEx9!$kTWTgCAT~Q
zs7iMePxe|l;H#Efh9pd(><*@}`e^AE3R?zAwe_K@*Fm87pAhUxDFsoP78>#_do!;L
z%1RK9>CmE;YxN?O-k~^-RaY^$ZO9x@QXyR^l$8~8-?*;%oa~#5qUS3xU~1N(HT~*!
zxSsL~c8P@=&j_68q1ZH9h-%h*C7S~#{00d90oGQjbxj!qU4$cCQ8b5ItDAN)h9f-0
zzI?vpW|jEVEEl(}F47Wi!)YXw2d16HoCe5(+CsSjbo=j&^AI!scv&;HfH}9|Kh12`
z-Wf)aS^J7e;ihz>9}R<9Qod@9DB2@IrABD6P+^u8rJ}-jXFPoPkS}_XU*Tn3bEa@KgtB{)Z4>Op55L47jzj$}6^|YJ#3=yv
zhY^@#4Nz=V;+B@+e@Od^9w#<6hd3!z<-!)>Uy()CQ6IY@}b
z2$hg>z1X2xwN_A288B@)Ce7>FT|;JXh5^oM-}Z0N*$Kdrv1;{dku#4CuiwWWEwxVCj`d70=+kRlaQ_t0cD@ognnb!Dk5#X?nWoM&2$<>ns
zNX(~$Hydce9A!Psfwx;~i2|p^MY{8)AHL+HPqrT3h3*=;;OM9yiz;dD0j+dL8di5g
zBPqD+z~pY)Zh%nBoLfO$d@3l=Ww3LL1kuRY{GQZBe&EjMZ`2Xcf+5B{^u37?1KFS-
zew0?
z8B2)O0G6lU+g(A-
zo70bE{qBN!NpRetxI@8PVbh9_D+r@a$fWzQk+zWO0^OrYhAq~YBd~AN3^$V}*fuz_
zyRtcOXyV-bdbc@V%-k1sh&_&|0sJfrcGx09K0pE}cUk%#?<4{ptB8Gg1%3F9X&tAl
z^7Z%QD9A7%VVd9{*;{=;BE#rc6e!-Vbgnid(??os)w;?`|8mM@&_ioxs
zY=n0W)fpKX?hH65+qsz3OW(a~7x=htdJV%%XhPeGAT7W
zaW^(7&d5;q*%`MHxDIkiFK<`Kwn;5QR*c@g>d}{7L2j
z-;V>JCIQ~xt93+dLCnfBs!8;PM^_abrHx@eg+VMyc8-knHh7|ihN^Vzpxng-7zP3C
zvx8H#Ra9no!hHYRncih7xwZR+HV+nct1^ko)GCoAcya?
z&L@Ws0;iDfR)|^9FuwdYq5cy{DtU)s6ErR8-`LyN4Ox{wM~mk7aX|i1LHj^zGyhXW
z6Rca1+3Z=&IMWIh2rzB{iE`3ik$JNkU_{j<$+7(~@5u0_4Pzf*V>3CtQZ^lS5PNW!Zy)JArr;`meVS6%2#GK+tZE?i%XrK--A+{2ww)%n64F2A+&>D
zfG9!ZWM^kLqvRynuiQ9CJpCa=tFPX6q=$#HSDRT*CvYwlwT;Adzdxk9as3q0{S~L2
zJjrI$R;iBUr2-cUFs6+_L`Wx#;D(ClBX(am!s}W-zPy5J5>!CW1KF}!=;BvgqPDUK
zZ~R)i=)a{Vo!ccXEy8zFM3gCY2DrhVRQ9<~UiyFfLe8DwUbgqw19oI%Am<0TL`FAd
zZF*<_z>%&6$2nHy7gl0q?W1BTvuMq9$9S|)pNbrw-UQ8*Io!V=BbThy^3e)-bic_(
zu8D@=VNo>YH7K!fUBCY6WtbUGyjG@&1BcOr;&OD^Ll_f%5P0Td*`8~tS|Y%d(CvW*
zaqy*8-^kpv%v|5P;vHJJZ6F`7)b>;1jtGYxQsuM@09fIhD{<5y)vzI
z=&X=%)G#PWhT%1FpYQ|%`9N?2jY*3^9e>dPh+PDxS8CwXd+VOB>Qov!B-9|e7Cqd<
zks2+IH;1FK7x_eg>**osl`{M%K*M+tYY53iVxu4*O8FeiX)Sbb&-JQ+ImJQy95Do1
zfVk)516n6wcu`Qe4t$eLSyXpoIM`E=!g}%M=sB?n9|ECPk2<#*$O;qgbz&Z~n4`6u
z9NGxdK!9`0asmIToayR9gCB28yl)y9)gdhhW0SwYkve*G2p48hWE%B&`(V!eaC7LA16h8AXa|wyLAPp)C
zl$u2R5l~u`=Ec-Dp*S=`Uzywjo?&-)bKxHRr7AulkOT;eHsd_xysM{Zvt61lK|toRz>n1|4Z
z(`SZuR|B3HLh@2=qo_0XVc$snImyP|(06td-z*%Nf#wId`3AuN@DA`mxsJb~O};@Z
z>2Oj$Hf%hE7C;Ac^r}$!*?fLE#YcZVAZx(J4?x5uCnYtZ6(3a_`iUG*8m&SB+?Coa
zRYDQgr_whhqRyG;D_j`-sUJZj0?~@_F)Tzgu6bm)0I2ddc!D^B)$t^ub?f${PDqf8
znc%5}0i#fz$?n*(p=#OC*c$4cuU-|TGeRFIJiTTNpNm%eA
z$vd#aC4xYBK)f=+Z!IR$jYXwruIje5PRTX
zigMe}x-nz+8or&H25lJ8#0CguY
znUz%?1qA4R<}+YxL8zg~^IA4M2(Q9H96EW9T73Ii=FOW`SZ+wNTxUL`m9=?0;@0qP0VX>a6mTIpselH1s{OW^b?zSW
z88Uk${;!3d@}wCTyFw$X4_GK^&a%Ws?b;#%Mz7tpsUsmC+mA$AK=c&Q#TM%TtSNz@
z;%RF=&-3TcN$ocWJ}=}37@ir>L=Xgu;j0`(`v6_OwUU4nEaLR293;^$2yxV^$Z15=
z;q&<4!>I>`VXIdA=g%ZCW56ywb@Jrb3Hn&;{ix7Qrk)*!WQeK^LiF8Zp4{TA;n*`7
zDs4#W$n{hGy_>b{F04VLeF+v-X|Wr+C`UkuW>b{0|3aavp!Wzb8Vd}AP|-R7?pR8P
zfR6K(v{t8UqEmBsba!_*hJIk46d%#`s0Ed9uDeg5zjSYcDxe^Y8RoFxVzj{~O~P}a6(c_`%=GTP`P2R1+ErWinO%p}
z$O52kEdT73fGCDjink-j2+nbI1$*@)6Z?!RR#C~f!*?){N;Bx
z>Xyvcbq<8Tle1+$3O3MIEwIXN5<7Dc36M$Whg8(g!E2Hki)5Iw1_#LQN8u*WpJ441
zpU*=3+ScLlN9qfUn)zM%udhBes}dm}wL9y~nKRuO(nFBovVlcI>0MAVY>ON#xflIU
zR_nO@yH>rBZ2Hpua&m%NPFA6h_@kf9nWd@LkXDB25p>*AIz~abPhgp=fN-RF{`1-;
z_k$N;F-hxhQk{68b^|&m|C*yq7h(KO5hjwX|NX6o74jK~oi_@m*aBpMCyF%38~fG>
zoglJe24P%CDwo6$K=*?uG*~3yB*Xx^U!Z0j6z@&LQA04F_K4
z2%_Cwxb7jUOQM9amm^4W2T+jAq;H0$q3AV2T6Jf+xVUr!a|?yna2^Y8#rqI+89_Yi
zekM@y(4j5>yrpx+Om^wGN#fQd#-FT|`<=Sgj8vW>~CDbMx!km6XlsTiO
z*zQ4W#VGeM^vvZ7F#yht5Y@hSb;YTvKG-$o$`VN(TH`<+qI$#u%>)j$9>f4N9M7jp
z^-eYwmM+*1gtJGwKhT-a&27pf>jH$woqdJpO5Axl-)$fCbb&1s0|G=IR>`eJP7u5F
zXo1QWbDbcy
z1SwC$FnY&(%khhW8946hg#BGd939?;7mdITna4sV?1JyT1=n3*P&5&nb7rvrm7iM%
zqu9r#W4fdJ_@+m0u68&TOo@cDChoW*Bz+eaF4L<*k;a{{W=x*4^AYXTZ6AY(;%99<
zqsWxq8jmp4Ih@fpwtF2M9UZSS&K-a6hDhZHRw@LQ)*-m>l+Y0rQ|s1+);z*(WO&|$
z5JI`x4IC;d?S2eC6C(vzDu>GPSzZ(WnIma9~mgdgThe>-~9)9qMKhzll*-SBlFa
zh|AVf=z*R$vL^}RBx1YpVvD@nQ-C@cVqPcEutC1(7FVYoPMwWR&Uc{K2{AA#L&u~d
zfeI4$ewc!5=$WY+B0v?ks6@wY&rnfd({*T#Xbo$k;{bI3Lvy^|NEB)|2p29&^~M(i
zKZbwLFth}mCvXUIQUp4pa3ClM@v9it?&;c?xN8Xm@GcY^?k!i|u-ZK$A|eRPkUDag
z>=#(a@i@pPYiN<6@iD=@Ln)}iIN23(cJk9=Z{$Dh_qD56ug+PBivUPLWBQ@YWLKmZ
z_tSAPKp7$<*L!%ZOGT}O5s+~UgIu*iy1mtbKs!NZEmaxS+-K3Vgp*WYb$9aTfqM!fy>jtD9K4Rlx+*_Z!rq
z#UkaOJICJj>Q95H1R+CNZ7hZ;R=0H4qp4=E`FRF1mgmXTiN9j3GVm
zAYz-Ukm9WyU=0o25Nbavg8&SKo6dF2vv7fRT<%B;1>|;oH}(%#QDeLtQSS?xD;seDuNlHKFyj0
z6+9tJ_{!aSRfV%{K(vqF68>l+pU9KnjXPBc{JAY|f-nx5I4USB&db~U%4KYCkyydu*t%5OcE3yNMDmvhzoCSd
z09NN?KN?9AFXGZ47g&y3sa#AIB_I)?w1?-hh;oA@m@BLNihD?KP>><+EKWjJK-oi+
z6VY$Kf#UPE6YyrL2I?F1hXQgJO-b?!@)vhn*wSkbG4$hw!5d_d@hlDRrl6=pL*De*
z^}Qva@vQcC5V4`KFPFu`?L?zF!cd7Xa@$`6qClR>Xu2V2@Q>_9L^9)kt;wQ*wn54o*8~0d0banCuXeU_zfaGAUJs3JNJV
zGTk2L?~)Pnc&@`lOkFgt1BDOiYS5P>_gH>$yMwdWnEW{YMIAN#V}}ehDdo9WZj+#v{|;y0yHxfk8GG2oQi0mq41gy)9LFM+U}5<-bNZQV4?x
z9pfR@jX?$d82fE1V*gr2+HEVKbGNbcE!j?T|Hl?!FTK+*NK{m|s<^l~Y{+!V;zWT<
z;-_L09PSZhvK%#_$Kf24)R
z&=}>TCNKh821|(FN|N~Dt{PV+J)xBVsGzinqmSn2z>1`Rlv-S#IWk;ftk!0O6qLyD
z2R`~XFHa7D={(luVXXx;;Upw;p6Cl83e%ZAZB?SGPPZ`W1yy&-H-58RAT!qJMSTKR
z2gg;*wph@=rm!K^D18+`;B`Z?AkX;(3&
zbT7L4M78MWj}Q@2w7vl
zC)yuU;)RU3zQDZ;o^ZfS1x2k?rX3oU_=u6J=%vVrEe3U%xSG>3a0dERVB8XipaSs~
zx%r_+ZqJ@Q5nzDi{l=KZPm0bQ=^yLNXJb?NgiM}md?hBPotyij{@Z}O+t0qS0NmDU
z><@g4DlY9YC=XD7!a3a?9UWXuDhTifX~cwkFOXJm43-lm&m(n}(c^cY==}jOmfZXx
zP@GiGjK`T?CZPuCdFwxa)?zI||HO`I4?>vqfr^k47*_Ve#>4(cY1~#X;TTMW
zDOB*o{0Zd70M-!90eU=fs-$mjeOyqTHtafw=nsqiP>N>715rr0dhEx^`A*gI8%iLL
zBq$Ce^$ro{6UZ=e*l4^9_4QZReB-X}*RN|@ak%)jZq)g~RWaZ2$F@~c%msc=P^R0q
zAL$zE1VGo8w8_T*Ezfz<#{s%FMtq^MFM*6m@sMk9fcy{)0XpU0^3)15ekxp|!~y&)
zMS@C=7VOtBeY?#0aR5IOO*j8jzDbH0RaMY-=BddM0ofr?6Y~DQ{Zp5p6fN|_B(t%(
z26Z_nxIGmHV2az5Zq_Z)-1(+;N)lkeOxV&t2L+f7_+Xg*
z%Fq}QB!q*l)&7$0wPD^94Wxp!()#t{$GX%7)r<@?gAa>6?0t$!Ez#`8n=K`E;qAd^
zuu|I_`C-D8?>-;0cA~g%y`E)>J!+8D1K%5b!~3Iqy}l0*%$zglSf?P$bwMpm_IC*B
zu?t4mynFX<0#33BdZ$VOz#25VcUKX+kAubC@8REn-)YL2u=~;fp|b+0`(D+w2Lan+=X(Bm_f&e
zvK3p~S6y*`JUh6J%i%;#=K^&?KFLu8{!zqc$-~UTz4G$5EnNTHbhgT{|HOJjs3B?S
z2n(oQ=V9-`qi+K?ii=k@3@4oR(wthpcD*06s!|Z)HD}iOVcc
zx*E9q6c6u|t>uKz+H+Vq=~-c+B6^j~sMVV{A0){dx-rJK4l|IbEd_yRvhPsA_~4UuA1NVPm6=
z0U$K|tb)sz(WN85Z=Yp<`ud^SZEaYrNSxb(uJOYziFqyW6-)u`az8bo5jJS%wbSfB
z7@L~uo-x18tl@Cl8Du(T_xm=%9B)>PrUOQG?8jku1?XgLIzH_np90vO&0Y7F$CP#2
zIr=@UX1#j#YN2c9;RA;=S8(=~rMKZgo)e!}j%;$|*>mK%)q=iSP=;h;qZ~&fA;25%
z?OBzZ`XvzkMXozMJFR{}LSh#h0oGNAY{b3Whs_Wfh_
zCeO70-Z#iu+{gP|c9)#1Lf9ATc2r)GB{~?Y)G#dFuVcU)TF7KFEdmC83(Njgv7tPt
z6xv!2m^>b5@p`QFUR(mDqEJhJXUP~%HMoZ`I#;8kh!X1N-Dai3g*}CWI
zJqEVvAr!B%+D>4es2K&GnP)U$fV%~<1(oK_ol;WR=sbdk1<#BIYM0PxI+8|5sBTQ1
z@cNI>>0f!s1faIj;K`~@Tb3?;c4T1T^w#)Qab@0!u|IzND0QaJuC|XnfyRRaNzWO8P?NErbgcPp227bw5N>
z09Y7ewKp0b(FH8}ezR;9%AUXGJQ9Hsl5Yu+=A4HI3(`F%5Jj$e7EXy4Ha6+sBvA(f
z=#qBO3%5jft3&h3bMP@F2|`-*V3~UUx_89qzfazP7dL@a&W{osyhT?{ZF1S}MZS
zA?ApTtFxr9OOnD4CrR`|fds>nG96X>H120g1$7OJb6aCOL^S;`^vw;{De$kLx7a?D
z>tf>;Un6+FG#(#5MRPBir;M-acR3s?ZM#*OHfgDB)wikTE*rls=u08U2Q8mvdHb|0
z!>;8AV@$`!F4L*6GPjR;0N*a1BrWCP%e`2A2J*5jmgQ58J2357nQM;kvYe`MdnG2n
zWx0H}k~csH^7=#+0*EnIXaWPI0r?rM(&&^|5`CEe#t(C%F??vGFrPDj{+(7q(3~1z
zZAqo}#f`1{Umz`O;!{juhdgyHQhV
ziW#n#g92=&LcA^4od{aIP{?Zn2Dj=$AoEIw%4cV7K%JPB6v<6ylP+yy6e(PQJ`8DZ~-5O^lrL~U4t(!S>LSNeV<8L^^Om%vq6=JcII!hTH1f=cj#!Q3>&x0Wg&Y8iJ~d$3R0L
z_M4m5Vh|n5Una&)w0Bg$7N;uUJB8CJtL|d;tv+DzVU&O^)
zfyKiV(c)mo6>sirG#Bg$!BwWvmDLm1=#4j2Q7M{8E=a9EkaExjT54rNvYqX;4U-#C
zI0{FD!RNuTi#e=(KhZsR`?c4vT9v&lURtVb_Uzfy+wUyFL|mdqG4GR9FdqoAI;hy`4iI8(&k`!Lvk4r4||B+Ykv}JfI6IB
zi~+CJ%be!jZr^j{2<9#zX=*xWTE)^B8K4#%VjU;<-PvFoUwtrt_{FH|!Nx|g%WeiY
zx8J1MRqt>2K0)-B_pN*X65HZC_Rb;+wmfu~IBz8s;;(R6)grpDI?7kfy1JN~Tkft%
z4N8m_$?6|P^J?TO*2M8#%!K5d4%{b$Y$AHTs8sZd%I)xm^l1tZ5K0
zGD*6)i}Uy}2AFF_h}71p53o)au(#%cF0^@8OZ5a04)I+_`d#PzPBIq&2h)qu*l57d
ze;whqb{JSr&sCKsGFD^0EkTjR8V_sE_+AEk=cxp!ohWal828f`!Mm-4FB)b8^(a2s
zcB!ee{NKKgLaTu`u~fm7_%E=(#EZEPh_(ouzV-LuSfT=<+1K($@GV_~k@tVSTuy`S}_U=l|Ul6W<9;Pgv^GBpo{V!p{j+^OBJPbAw;1@
zKR#U`XK>W-+3F&OwKYI~*MfwZp4oh|In*L%qIkLl9|yG$Ti`v`(#
zT4qRYF6unJ(D#taKvj}l&T7!H@I?u~C>*Be2i2HEhV)-rR8#(3&MpO82mm`cFGV2d
z*=JlWVbps7?zuIAO=ZcFfJH~Lm{KKDQc~s)L)PdBkxQpdv>;_A;KEpkj8yLY<8>t%
zyzJ%eZ5;m|Fd!yDUywWX(esOcAzqW_#^QoUmbAg-N(w6N0TNfz;4`E(Gv~ZxTX1`u
zUh!`P>WGUI+ib_fqc-)eOTePOZwtzJP+yXLIdLf676WaO0?4#nyja(z`xbOuCO_`)
zVc;@NFV^O88!Ns|NVXcewE17V`58AI|DeLar3ConJE^cP2K#wIIB`9-{@{L;AUg^q
zXR1b27$(1?h@ajKc*QSJL|yev<5vtI%^nHUgX9-MYR{NadbaSFa>+)c)i_O<6QAdj~UlU?}*V#
zr)Y{byy#Ms0xRX<>^zcWi@CvA17H*=RRop;L0PwS=~C5!=fXfg7@^TH4GB#OTZ2hL
zUERT(wk%Zf!&|`_L0Qq_5O-N?UQg-D52W+bKcK6n5Y7NIRhD75*>Q%n{D2zH!Hb00
zg@I~^1RqH*y)vW6vRemJSeK%!8>Tn(vsFb_`$4nO_s%A~9S
zz?|rkeprqK%__GM1#a8D=edGf+ko#JpzYG6&nz2ery<%HbtF)<0x^#iQCa%H4?P1@
zP>_k5{5%|mZhG75+aqrm{Ok6!O+AU-&KWN1Hay6^Bn}b?{vi}qDlBrq14E!;C@Kux
ztmB65XxeyJu6$`s!LVqcIT;SP=w!2SvIxjaX0t5k-)ufCAD)dgrx*
zG_j#nMHK1MI~b+d0BMIJML|KS0@B-E`xx_*$otLQ|9p4m{%68WVxye%{GMk&d#}CL
z+K?mhA4DyvPGUXboCb&(gIO!)ZPZl9%TMcQL!V(7y`U&fJQdq$jSR*s`%nJ(RT&U{
z&U-`Ln`rGn=PmBQlC6VC<^GfXKmUOHg;!51j3OZ@irbTR&Qiks)O$?UC?8mjly(u@
zUB9c1U~#~_iz1B6?t98*DL$NgV@HehQUUbdP-#EuUO6T5JY(X(UyC;uaqRDB`|FTyQt#RjKj-b6$4D|%qhWF<%w<%C
zp4@iJ>2Q2lkeGTji73_6~3H
zV%+)Edj=M~FgqW>9|lQ0Gpj?WMJd0z+>rV<6vMLaIsX=EI_rS%_>5>I)lp>13^Jw?
zk^~0XrGd&noZN<0BST^`)gbdYI&kz!q}PVMxyg{2Dtt)f66#(acr%(qUHp^O>VE44
z2=jvBF;$?r2)1y4LY1EoJ^;Yo1QvqhGZ?P{9g(3twz&e>!=(UVj9VSUPGQms2wVE8
zP8jiV;WvdZB_vwQtQE+4UeBB9k=fl3d&Oy|
z)iv^~LjJ7WhMP>#tWoA-yTLo_O*f&9kfTa**?4v8(oX
zAVEU|lFwa;UOKwhGyCkJQe~}xFx}7I-kvjJj2MonQn)ICfqcfW;)Sj6A-oSe*2ZCcU<`?r1;nBcioXxIaVn5C
z8Mh8S`eF7{fMD8kz|L3)IHjF3sh9RZc-5+%_#%;mZC~h@4gHp23{a}E-y4DV11O?`
z2S-eilV4~hyK^Mhup?`LCrQ-WWE!3HqvJx3Z8Ha)Kn6xd@3U;jOrF-$p!WewRYnlL&%A`6F;;FM7rGlwe5M;Jio~12-KiPUEB)NbJ#DC#pM;aQGR6
zN=&^9mroGY5d@SHvIRo0c{u9yUg9lP-WnB!BXW<@1Jb!!Xd)O_nUtMOzd~eC7g}~>6oa~26=r?Pk;W4G
zGVLS=5D<{QNd6$Co$G3(wHN09j2Mkbgr;+EC?H0