From 95a210b7c098cbac63f740cb27acf7c967e29310 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Sat, 7 Mar 2026 09:14:06 -0500 Subject: [PATCH 1/6] Bug fix in jvp updater --- src/training/updater/jvp_reg.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/training/updater/jvp_reg.py b/src/training/updater/jvp_reg.py index 5c4f19b..38c9bc1 100644 --- a/src/training/updater/jvp_reg.py +++ b/src/training/updater/jvp_reg.py @@ -101,12 +101,23 @@ def _compute_jvp_gradients( deltax: torch.Tensor, ) -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor]: """Compute JVP-regularized gradients combining current, memory, and JVP terms.""" + # NOTE: torch.func (grad/jvp/vjp) requires the function being transformed to be + # free of in-place mutations of captured tensors. BatchNorm layers mutate + # buffers (e.g., num_batches_tracked/running stats) in training mode. + # Force eval mode during the transformed computations. + was_training = self.model.training + self.model.eval() + + # Cache buffers for stateless/functional execution (e.g., BatchNorm running stats). + buffers = OrderedDict(self.model.named_buffers()) + for p in params.values(): p.requires_grad_(True) # - Define loss function for functional API def loss_fn(p, x, y): - pred = functional_call(self.model, p, (x,)) + # Pass both params and buffers to avoid accidental stateful access. + pred = functional_call(self.model, (p, buffers), (x,)) return self.criterion(pred, y) / self.cfg.train.grad_accumulation_steps # - Compute gradients @@ -141,4 +152,8 @@ def jvp_func(p, tangents): loss_curr = loss_fn(params, x_curr, y_curr) loss_mem = loss_fn(params, x_mem, y_mem) + # Restore original training mode + if was_training: + self.model.train() + return combined_grads, loss_curr, loss_mem From 73ec3a0f88c9e8867752a6211ffdd63847ce58b0 Mon Sep 17 00:00:00 2001 From: Steffen Date: Wed, 1 Apr 2026 16:34:10 -0400 Subject: [PATCH 2/6] fixed a bug, where batch norm would behave differntly with jvp and the other grads. Made everything but jvp inplace to save mem --- src/training/updater/jvp_reg.py | 117 ++++++++++++-------------------- 1 file changed, 45 insertions(+), 72 deletions(-) diff --git a/src/training/updater/jvp_reg.py b/src/training/updater/jvp_reg.py index 38c9bc1..8da8056 100644 --- a/src/training/updater/jvp_reg.py +++ b/src/training/updater/jvp_reg.py @@ -64,96 +64,69 @@ def fwd_bwd( loss_curr: Loss on current task loss_mem: Loss on memory task """ - if hist_batch is None: - return super().fwd_bwd(batch, hist_batch) - x_curr, y_curr = batch - x_mem, y_mem = hist_batch - # - Compute deltax direction - deltax = ( - self.jvp_deltax_norm - * (x_mem - x_curr) - / (torch.linalg.norm(x_mem) + torch.linalg.norm(x_curr)) - ) + # Current gradients ### + loss_cur = super().fwd_bwd( + batch, hist_batch + ) # fill the model gradients with the default grad info - # - Compute combined gradients - assert self._params is not None - grad_dict, loss_curr, loss_mem = self._compute_jvp_gradients( - self._params, x_curr, y_curr, x_mem, y_mem, deltax - ) + ### JVP gradients ### - self.grad_dict = ( - {k: self.grad_dict[k] + grad_dict[k] for k in grad_dict} - if self.grad_dict is not None - else grad_dict - ) + self._compute_jvp_gradients( + self._params, batch, hist_batch + ) # adds jvp gradiens to existing gradients + ### History gradients #### + if hist_batch is None: + raise ValueError("JVP regularizer expects history data, but there was none") + x_mem, y_mem = hist_batch + + outputs_mem = self.model(x_mem) + loss_mem = self.criterion(outputs_mem, y_mem) + loss_mem = loss_mem / self.cfg.train.grad_accumulation_steps + loss_mem.backward() # adds the history gradients to the default gradients inplace. self.loss_mem += loss_mem.item() - return loss_curr.item() + + return loss_cur def _compute_jvp_gradients( self, params: OrderedDict[str, torch.nn.Parameter], - x_curr: torch.Tensor, - y_curr: torch.Tensor, - x_mem: torch.Tensor, - y_mem: torch.Tensor, - deltax: torch.Tensor, - ) -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor]: + batch: tuple[torch.Tensor, torch.Tensor], + hist_batch: tuple[torch.Tensor, torch.Tensor], + ) -> None: """Compute JVP-regularized gradients combining current, memory, and JVP terms.""" - # NOTE: torch.func (grad/jvp/vjp) requires the function being transformed to be - # free of in-place mutations of captured tensors. BatchNorm layers mutate - # buffers (e.g., num_batches_tracked/running stats) in training mode. - # Force eval mode during the transformed computations. - was_training = self.model.training - self.model.eval() - - # Cache buffers for stateless/functional execution (e.g., BatchNorm running stats). - buffers = OrderedDict(self.model.named_buffers()) - - for p in params.values(): - p.requires_grad_(True) # - Define loss function for functional API - def loss_fn(p, x, y): - # Pass both params and buffers to avoid accidental stateful access. - pred = functional_call(self.model, (p, buffers), (x,)) - return self.criterion(pred, y) / self.cfg.train.grad_accumulation_steps - - # - Compute gradients - grad_fn = grad(loss_fn, argnums=0) - - # - Current task gradient - grad_curr = grad_fn(params, x_curr, y_curr) - - # - Memory task gradient - grad_mem = grad_fn(params, x_mem, y_mem) - - # - JVP computation def f(p, x): - return loss_fn(p, x, y_mem) + pred = functional_call(self.model, (p,), (x,)) + return ( + self.criterion(pred, hist_batch[1]) + / self.cfg.train.grad_accumulation_steps + ) def jvp_func(p, tangents): - return jvp(f, (p, x_mem), tangents)[1] + return jvp(f, (p, hist_batch[0]), tangents)[1] # - Use current gradient as tangent direction - tangents = OrderedDict((k, grad_curr[k]) for k in params) + tangents = OrderedDict( + (name, param.grad.detach().clone()) for name, param in params.items() + ) + + deltax = ( + self.jvp_deltax_norm + * (hist_batch[0] - batch[0]) + / (torch.linalg.norm(hist_batch[0]) + torch.linalg.norm(batch[0])) + ) + # - JVP computation grad_jvp = grad(jvp_func)(params, (tangents, deltax)) # - Combine gradients - combined_grads = { - k: grad_curr[k] + grad_mem[k] + self.jvp_lambda * grad_jvp[k] - for k in params - } - - # - Compute loss values - with torch.no_grad(): - loss_curr = loss_fn(params, x_curr, y_curr) - loss_mem = loss_fn(params, x_mem, y_mem) - - # Restore original training mode - if was_training: - self.model.train() - - return combined_grads, loss_curr, loss_mem + for n, p in self.model.named_parameters(): + if p.grad is not None: + p.grad += self.jvp_lambda * grad_jvp[n] + else: + raise KeyError( + "param ", n, " has no grad, but JVP regularizer expected one." + ) From 4fe214fe1e693beef81fc00ef708cf621e8f5f92 Mon Sep 17 00:00:00 2001 From: anagainaru Date: Fri, 3 Apr 2026 21:47:00 -0400 Subject: [PATCH 3/6] Changes to pass mypy --- src/training/updater/jvp_reg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/training/updater/jvp_reg.py b/src/training/updater/jvp_reg.py index 8da8056..94ab474 100644 --- a/src/training/updater/jvp_reg.py +++ b/src/training/updater/jvp_reg.py @@ -71,10 +71,10 @@ def fwd_bwd( ) # fill the model gradients with the default grad info ### JVP gradients ### - - self._compute_jvp_gradients( - self._params, batch, hist_batch - ) # adds jvp gradiens to existing gradients + if self._params is not None and hist_batch is not None: + self._compute_jvp_gradients( + self._params, batch, hist_batch + ) # adds jvp gradiens to existing gradients ### History gradients #### if hist_batch is None: From c26ab75d196d961c9035fa5e6fff9fe198bf2fa7 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Fri, 3 Apr 2026 22:00:37 -0400 Subject: [PATCH 4/6] Wip --- src/training/updater/jvp_reg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/training/updater/jvp_reg.py b/src/training/updater/jvp_reg.py index 94ab474..8fcd18d 100644 --- a/src/training/updater/jvp_reg.py +++ b/src/training/updater/jvp_reg.py @@ -110,7 +110,10 @@ def jvp_func(p, tangents): # - Use current gradient as tangent direction tangents = OrderedDict( - (name, param.grad.detach().clone()) for name, param in params.items() + (name, param.grad.detach().clone()) + for name, param in params.items() + if param.grad is not None + #(name, param.grad.detach().clone()) for name, param in params.items() ) deltax = ( From eb051e8691409bcb7f5a28cc8bcaf608ac1d46c6 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Fri, 3 Apr 2026 22:06:02 -0400 Subject: [PATCH 5/6] format --- src/training/updater/jvp_reg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/training/updater/jvp_reg.py b/src/training/updater/jvp_reg.py index 8fcd18d..0dbfcff 100644 --- a/src/training/updater/jvp_reg.py +++ b/src/training/updater/jvp_reg.py @@ -113,7 +113,6 @@ def jvp_func(p, tangents): (name, param.grad.detach().clone()) for name, param in params.items() if param.grad is not None - #(name, param.grad.detach().clone()) for name, param in params.items() ) deltax = ( From 0090a24b5acd839ea9a0f365311b318468b40e80 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Tue, 7 Apr 2026 12:57:33 -0500 Subject: [PATCH 6/6] Using the base updater when jvp doesn't have enough history --- src/training/updater/jvp_reg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/training/updater/jvp_reg.py b/src/training/updater/jvp_reg.py index 0dbfcff..58b06dd 100644 --- a/src/training/updater/jvp_reg.py +++ b/src/training/updater/jvp_reg.py @@ -78,7 +78,7 @@ def fwd_bwd( ### History gradients #### if hist_batch is None: - raise ValueError("JVP regularizer expects history data, but there was none") + return super().fwd_bwd(batch) x_mem, y_mem = hist_batch outputs_mem = self.model(x_mem)