diff --git a/came_pytorch/CAME.py b/came_pytorch/CAME.py index 280edfd..8bb3ff1 100644 --- a/came_pytorch/CAME.py +++ b/came_pytorch/CAME.py @@ -109,10 +109,7 @@ def step(self, closure=None): else: state["exp_avg_sq"] = torch.zeros_like(grad) - state["RMS"] = 0 - state["step"] += 1 - state["RMS"] = self._rms(p.data) update = (grad**2) + group["eps"][0] if factored: @@ -171,4 +168,4 @@ def step(self, closure=None): update.mul_(group["lr"]) p.data.add_(-update) - return loss \ No newline at end of file + return loss