104 changes: 47 additions & 57 deletions src/training/updater/jvp_reg.py
@@ -64,81 +64,71 @@ def fwd_bwd(
loss_curr: Loss on current task
loss_mem: Loss on memory task
"""
if hist_batch is None:
return super().fwd_bwd(batch, hist_batch)
x_curr, y_curr = batch
x_mem, y_mem = hist_batch

# - Compute deltax direction
deltax = (
self.jvp_deltax_norm
* (x_mem - x_curr)
/ (torch.linalg.norm(x_mem) + torch.linalg.norm(x_curr))
)
### Current gradients ###
loss_cur = super().fwd_bwd(
batch, hist_batch
) # fill the model gradients with the default grad info

# - Compute combined gradients
assert self._params is not None
grad_dict, loss_curr, loss_mem = self._compute_jvp_gradients(
self._params, x_curr, y_curr, x_mem, y_mem, deltax
)
### JVP gradients ###
if self._params is not None and hist_batch is not None:
self._compute_jvp_gradients(
self._params, batch, hist_batch
) # adds JVP gradients to the existing gradients

self.grad_dict = (
{k: self.grad_dict[k] + grad_dict[k] for k in grad_dict}
if self.grad_dict is not None
else grad_dict
)
### History gradients ###
if hist_batch is None:
return super().fwd_bwd(batch)
x_mem, y_mem = hist_batch

outputs_mem = self.model(x_mem)
loss_mem = self.criterion(outputs_mem, y_mem)
loss_mem = loss_mem / self.cfg.train.grad_accumulation_steps
loss_mem.backward() # adds the history gradients to the default gradients in place.
self.loss_mem += loss_mem.item()
return loss_curr.item()

return loss_cur

def _compute_jvp_gradients(
self,
params: OrderedDict[str, torch.nn.Parameter],
x_curr: torch.Tensor,
y_curr: torch.Tensor,
x_mem: torch.Tensor,
y_mem: torch.Tensor,
deltax: torch.Tensor,
) -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
batch: tuple[torch.Tensor, torch.Tensor],
hist_batch: tuple[torch.Tensor, torch.Tensor],
) -> None:
"""Compute JVP-regularized gradients combining current, memory, and JVP terms."""
for p in params.values():
p.requires_grad_(True)

# - Define loss function for functional API
def loss_fn(p, x, y):
pred = functional_call(self.model, p, (x,))
return self.criterion(pred, y) / self.cfg.train.grad_accumulation_steps

# - Compute gradients
grad_fn = grad(loss_fn, argnums=0)

# - Current task gradient
grad_curr = grad_fn(params, x_curr, y_curr)

# - Memory task gradient
grad_mem = grad_fn(params, x_mem, y_mem)

# - JVP computation
def f(p, x):
return loss_fn(p, x, y_mem)
pred = functional_call(self.model, (p,), (x,))
return (
self.criterion(pred, hist_batch[1])
/ self.cfg.train.grad_accumulation_steps
)

def jvp_func(p, tangents):
return jvp(f, (p, x_mem), tangents)[1]
return jvp(f, (p, hist_batch[0]), tangents)[1]

# - Use current gradient as tangent direction
tangents = OrderedDict((k, grad_curr[k]) for k in params)
tangents = OrderedDict(
(name, param.grad.detach().clone())
for name, param in params.items()
if param.grad is not None
)

deltax = (
self.jvp_deltax_norm
* (hist_batch[0] - batch[0])
/ (torch.linalg.norm(hist_batch[0]) + torch.linalg.norm(batch[0]))
)

# - JVP computation
grad_jvp = grad(jvp_func)(params, (tangents, deltax))

# - Combine gradients
combined_grads = {
k: grad_curr[k] + grad_mem[k] + self.jvp_lambda * grad_jvp[k]
for k in params
}

# - Compute loss values
with torch.no_grad():
loss_curr = loss_fn(params, x_curr, y_curr)
loss_mem = loss_fn(params, x_mem, y_mem)

return combined_grads, loss_curr, loss_mem
for n, p in self.model.named_parameters():
if p.grad is not None:
p.grad += self.jvp_lambda * grad_jvp[n]
else:
raise KeyError(
    f"param {n} has no grad, but the JVP regularizer expected one."
)