
Conversation

@mt-caret (Contributor) commented Dec 3, 2024

Proposed changes

The original implementation of AdamW doesn't include bias correction (#72). I found that this causes problems even on a trivial task such as memorization with a GPT2-like architecture, whereas the equivalent PyTorch implementation doesn't exhibit this behavior; the changes in this PR resolve the issue.

I've manually confirmed that this matches PyTorch's behavior up to small floating-point differences, and I've also replicated a simpler version of it in the mlx tests, which fails on main but passes on my branch.
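
For reference, the fix amounts to the standard Adam/AdamW bias correction: m and v are initialized to zero, so the raw estimates are biased toward zero early in training, and the correction divides them by (1 - b1^t) and (1 - b2^t) to undo this. A minimal standalone sketch of the corrected update (an illustrative helper written for this note, not code from the PR):

import mlx.core as mx

def adamw_step(p, g, m, v, step, lr=1e-3, b1=0.9, b2=0.999,
               eps=1e-8, weight_decay=0.01, bias_correction=True):
    # Decoupled weight decay (the "W" in AdamW).
    p = p * (1 - lr * weight_decay)
    # Exponential moving averages of the gradient and its square.
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * mx.square(g)
    if bias_correction:
        # Rescale to undo the zero-initialization bias in the early steps.
        m_hat = m / (1 - b1**step)
        v_hat = v / (1 - b2**step)
        p = p - lr * m_hat / (mx.sqrt(v_hat) + eps)
    else:
        p = p - lr * m / (mx.sqrt(v) + eps)
    return p, m, v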

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@guillaume-osmo commented

Yes, thanks for adding this option; this is indeed impactful.

@angeloskath (Member) left a comment

Thanks for adding this. Several people have requested this so I think it is time to add it.

I left some comments regarding the implementation. I think it will both be cleaner and avoid any extra cost in the case where we don't use mx.compile. Let me know what you think.

- return parameter - lr * m / (mx.sqrt(v) + eps)
+ return parameter - step_size * m / (
+     mx.sqrt(v) / bias_correction2_sqrt + eps
+ ).astype(gradient.dtype)
@angeloskath (Member)

Could you instead write this with a simple if? It would be faster in case someone is not compiling their step function. Namely something like the following:

if bias_correction:
    # Bias-corrected step: scale lr by 1 / (1 - b1**step) and divide the
    # denominator by sqrt(1 - b2**step).
    numerator = lr / (1 - b1**step) * m
    denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
    return parameter - numerator / denominator
else:
    return parameter - lr * m / (mx.sqrt(v) + eps)

@mt-caret (Contributor, Author)

That makes sense; done.
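
For later readers, a minimal usage sketch of the resulting option (assuming the flag is exposed as bias_correction on mlx.optimizers.AdamW, as in the diff above, and that it defaults to the previous uncorrected behavior):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 4)
# bias_correction is the option discussed in this PR; omitting it is assumed
# to keep the original (uncorrected) update for existing scripts.
optimizer = optim.AdamW(learning_rate=1e-3, bias_correction=True)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 4))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)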

  Our Adam implementation follows the original paper and omits the bias
- correction in the first and second moment estimates. In detail,
+ correction in the first and second moment estimates by default. In detail,
@angeloskath (Member)

I would simply remove that comment and document the bias_correction argument below in the args.

@mt-caret (Contributor, Author)

I've adjusted the comment accordingly. Does this look reasonable to you?

- correction in the first and second moments for AdamW. We update the weights
- with a weight_decay (:math:`\lambda`) value:
+ correction in the first and second moments for AdamW by default. We update
+ the weights with a weight_decay (:math:`\lambda`) value:
@angeloskath (Member)

Same as for Adam.

@mt-caret requested a review from angeloskath on December 4, 2024, 05:48
@angeloskath (Member) left a comment

Looks good, thanks!

@angeloskath (Member) commented

@mt-caret Could you skip the test if PyTorch is not available? Then I can merge, thanks!

@mt-caret (Contributor, Author) commented Dec 5, 2024

I should've looked more carefully at the other tests which use torch, my bad. Fixed!

@mt-caret (Contributor, Author) commented Dec 6, 2024

...and also applied the pre-commit thing 😅

@angeloskath merged commit fd3377d into ml-explore:main on Dec 6, 2024
5 checks passed
@mt-caret deleted the add-bias-correction-to-adam branch on December 7, 2024, 01:22