Support bias correction in Adam and AdamW optimizers #1640
Conversation
Yes, thanks for adding this option; this is indeed impactful.
angeloskath
left a comment
Thanks for adding this. Several people have requested it, so I think it is time to add it.
I left some comments regarding the implementation. I think it will be both cleaner and have the benefit of adding no extra cost when mx.compile is not used. Let me know what you think.
python/mlx/optimizers/optimizers.py
Outdated
-        return parameter - lr * m / (mx.sqrt(v) + eps)
+        return parameter - step_size * m / (
+            mx.sqrt(v) / bias_correction2_sqrt + eps
+        ).astype(gradient.dtype)
Could you instead write this with a simple if? It would be faster in case someone is not compiling their step function. Namely something like the following:
if bias_correction:
    numerator = lr / (1 - b1**step) * m
    denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
    return parameter - numerator / denominator
else:
    return parameter - lr * m / (mx.sqrt(v) + eps)
That makes sense; done.
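For context, a minimal sketch of what the update could look like after folding in the suggestion. This is not the merged code: the names step_size and bias_correction2_sqrt mirror the diff above, while the standalone function and scalar hyperparameters are simplifications for illustration.

import mlx.core as mx


def adam_step(parameter, m, v, lr, b1, b2, eps, step, bias_correction):
    """One Adam parameter update, given already-updated moments m and v.

    `step` is the 1-based iteration count; hyperparameters are Python scalars
    here to keep the sketch self-contained.
    """
    if bias_correction:
        # Fold the corrections into the step size and the denominator so the
        # uncorrected branch below pays no extra cost.
        step_size = lr / (1 - b1**step)
        bias_correction2_sqrt = (1 - b2**step) ** 0.5
        return parameter - step_size * m / (mx.sqrt(v) / bias_correction2_sqrt + eps)
    return parameter - lr * m / (mx.sqrt(v) + eps)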
python/mlx/optimizers/optimizers.py
Outdated
     Our Adam implementation follows the original paper and omits the bias
-    correction in the first and second moment estimates. In detail,
+    correction in the first and second moment estimates by default. In detail,
I would simply remove that comment and document the bias_correction argument below in the args.
I've adjusted the comment accordingly. Does this look reasonable to you?
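For example, the Args entry could read something like the following (a hypothetical docstring fragment; the exact wording and default in the merged PR may differ):

    bias_correction (bool, optional): If set to ``True``, bias correction is
        applied to the first and second moment estimates. Default: ``False``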
python/mlx/optimizers/optimizers.py
Outdated
-    correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (:math:`\lambda`) value:
+    correction in the first and second moments for AdamW by default. We update
+    the weights with a weight_decay (:math:`\lambda`) value:
Same as for Adam.
angeloskath
left a comment
Looks good, thanks!
@mt-caret Could you skip the test if there is no PyTorch available? Then I can merge, thanks!
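One common way to do this, sketched here under the assumption of a unittest-based test file; the class, test name, and guard variable are made up for illustration and are not the code from the PR.

import unittest

try:
    import torch  # noqa: F401

    has_torch = True
except ImportError:
    has_torch = False


class TestAdamWBiasCorrection(unittest.TestCase):
    @unittest.skipIf(not has_torch, "Requires PyTorch to be installed")
    def test_matches_torch(self):
        # The PyTorch comparison from the PR would go here.
        ...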
I should've looked more carefully at the other tests that use torch, my bad. Fixed!
...and also applied the pre-commit thing 😅
Proposed changes
The original implementation of AdamW doesn't include bias correction (#72). I found that this causes problems when using it to learn a trivial task, such as memorization with a GPT2-like architecture, whereas the equivalent PyTorch implementation doesn't exhibit this behavior; the changes in this PR resolve the issue.
I've manually confirmed that this matches PyTorch behavior up to some small floating-point differences, and also replicated a simpler version of it in the mlx tests, which breaks on main but passes on my branch.
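Roughly, the kind of parity check described above looks like the sketch below. This is not the test added in the PR: the bias_correction keyword is the argument introduced here, the shapes, step count, and tolerance are illustrative, and it assumes the default hyperparameters of mlx.optimizers.AdamW and torch.optim.AdamW line up.

import mlx.core as mx
import mlx.optimizers as optim
import numpy as np
import torch

np.random.seed(0)
w0 = np.random.normal(size=(8, 8)).astype(np.float32)
g = np.random.normal(size=(8, 8)).astype(np.float32)

# MLX side: AdamW with the new bias_correction flag enabled.
mlx_params = {"w": mx.array(w0)}
mlx_grads = {"w": mx.array(g)}
mlx_opt = optim.AdamW(learning_rate=1e-3, bias_correction=True)

# PyTorch side: torch.optim.AdamW always applies bias correction.
torch_param = torch.nn.Parameter(torch.tensor(w0))
torch_opt = torch.optim.AdamW([torch_param], lr=1e-3)

for _ in range(10):
    mlx_params = mlx_opt.apply_gradients(mlx_grads, mlx_params)
    torch_param.grad = torch.tensor(g)
    torch_opt.step()

np.testing.assert_allclose(
    np.array(mlx_params["w"]), torch_param.detach().numpy(), atol=1e-5
)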
Checklist
Put an x in the boxes that apply.
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes