diff --git a/src/training/updater/jvp_reg.py b/src/training/updater/jvp_reg.py
index 5c4f19b..58b06dd 100644
--- a/src/training/updater/jvp_reg.py
+++ b/src/training/updater/jvp_reg.py
@@ -64,76 +64,70 @@ def fwd_bwd(
         loss_curr: Loss on current task
         loss_mem: Loss on memory task
         """
-        if hist_batch is None:
-            return super().fwd_bwd(batch, hist_batch)
-        x_curr, y_curr = batch
-        x_mem, y_mem = hist_batch
-        # - Compute deltax direction
-        deltax = (
-            self.jvp_deltax_norm
-            * (x_mem - x_curr)
-            / (torch.linalg.norm(x_mem) + torch.linalg.norm(x_curr))
-        )
-        # - Compute combined gradients
-        assert self._params is not None
-        grad_dict, loss_curr, loss_mem = self._compute_jvp_gradients(
-            self._params, x_curr, y_curr, x_mem, y_mem, deltax
-        )
-        self.grad_dict = (
-            {k: self.grad_dict[k] + grad_dict[k] for k in grad_dict}
-            if self.grad_dict is not None
-            else grad_dict
-        )
+        ### Current gradients ###
+        if hist_batch is None:
+            # No memory batch: run the default update path exactly once.
+            return super().fwd_bwd(batch)
+        loss_cur = super().fwd_bwd(
+            batch, hist_batch
+        )  # fill the model gradients with the default grad info
+
+        ### JVP gradients ###
+        if self._params is not None:
+            self._compute_jvp_gradients(
+                self._params, batch, hist_batch
+            )  # adds JVP gradients to the existing gradients in place
+
+        ### History gradients ###
+        x_mem, y_mem = hist_batch
+        outputs_mem = self.model(x_mem)
+        loss_mem = self.criterion(outputs_mem, y_mem)
+        loss_mem = loss_mem / self.cfg.train.grad_accumulation_steps
+        loss_mem.backward()  # adds the history gradients to the default gradients inplace.
         self.loss_mem += loss_mem.item()
-        return loss_curr.item()
+
+        return loss_cur
 
     def _compute_jvp_gradients(
         self,
         params: OrderedDict[str, torch.nn.Parameter],
-        x_curr: torch.Tensor,
-        y_curr: torch.Tensor,
-        x_mem: torch.Tensor,
-        y_mem: torch.Tensor,
-        deltax: torch.Tensor,
-    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
+        batch: tuple[torch.Tensor, torch.Tensor],
+        hist_batch: tuple[torch.Tensor, torch.Tensor],
+    ) -> None:
         """Compute JVP-regularized gradients combining current, memory, and JVP terms."""
-        for p in params.values():
-            p.requires_grad_(True)
 
         # - Define loss function for functional API
-        def loss_fn(p, x, y):
-            pred = functional_call(self.model, p, (x,))
-            return self.criterion(pred, y) / self.cfg.train.grad_accumulation_steps
-
-        # - Compute gradients
-        grad_fn = grad(loss_fn, argnums=0)
-
-        # - Current task gradient
-        grad_curr = grad_fn(params, x_curr, y_curr)
-
-        # - Memory task gradient
-        grad_mem = grad_fn(params, x_mem, y_mem)
-
-        # - JVP computation
         def f(p, x):
-            return loss_fn(p, x, y_mem)
+            pred = functional_call(self.model, (p,), (x,))
+            return (
+                self.criterion(pred, hist_batch[1])
+                / self.cfg.train.grad_accumulation_steps
+            )
 
         def jvp_func(p, tangents):
-            return jvp(f, (p, x_mem), tangents)[1]
+            return jvp(f, (p, hist_batch[0]), tangents)[1]
 
         # - Use current gradient as tangent direction
-        tangents = OrderedDict((k, grad_curr[k]) for k in params)
+        tangents = OrderedDict(
+            (name, param.grad.detach().clone())
+            for name, param in params.items()
+            if param.grad is not None
+        )
+
+        deltax = (
+            self.jvp_deltax_norm
+            * (hist_batch[0] - batch[0])
+            / (torch.linalg.norm(hist_batch[0]) + torch.linalg.norm(batch[0]))
+        )
+
+        # - JVP computation
         grad_jvp = grad(jvp_func)(params, (tangents, deltax))
 
         # - Combine gradients
-        combined_grads = {
-            k: grad_curr[k] + grad_mem[k] + self.jvp_lambda * grad_jvp[k]
-            for k in params
-        }
-
-        # - Compute loss values
-        with torch.no_grad():
-            loss_curr = loss_fn(params, x_curr, y_curr)
-            loss_mem = loss_fn(params, x_mem, y_mem)
-
-        return combined_grads, loss_curr, loss_mem
+        for n, p in self.model.named_parameters():
+            if p.grad is not None:
+                p.grad += self.jvp_lambda * grad_jvp[n]
+            else:
+                raise KeyError(
+                    f"param {n} has no grad, but JVP regularizer expected one."
+                )