Flax dtype-dependent numerical masking #21197

Merged
gante merged 4 commits into huggingface:main from gante:flax_opt_batch on Jan 19, 2023

gante (Contributor) commented Jan 19, 2023

What does this PR do?

Fixes #21176

For some models, the value used for numerical masking in our Flax code was incompatible with the desired dtype. This PR fixes that by using, as the mask value, the minimum value representable in the corresponding dtype.
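
As a rough illustration of the idea (a minimal sketch, not the code from this PR), the snippet below builds an additive attention bias whose masked positions hold the minimum finite value of the target dtype. The helper name `additive_attention_bias` and the constant `-1e10` are hypothetical and are used only to contrast the two approaches.

```python
import jax.numpy as jnp


def additive_attention_bias(attention_mask, dtype=jnp.float32):
    """Hypothetical helper: turn a 0/1 attention mask into an additive bias."""
    # Minimum finite value of the dtype (-65504 for float16).
    mask_value = jnp.finfo(dtype).min
    return jnp.where(
        attention_mask > 0,
        jnp.zeros(attention_mask.shape, dtype=dtype),
        jnp.full(attention_mask.shape, mask_value, dtype=dtype),
    )


# Left-padded batch with one sequence: the first two positions are padding.
attention_mask = jnp.array([[0, 0, 1, 1]])
bias_fp16 = additive_attention_bias(attention_mask, dtype=jnp.float16)

# For contrast: an illustrative hard-coded constant (an assumption, not taken
# from the PR) overflows to -inf once it is cast down to float16.
hardcoded_fp16 = jnp.asarray(-1e10, dtype=jnp.float32).astype(jnp.float16)
```

With the dtype-dependent value, the bias stays finite in fp16, whereas the hard-coded constant becomes `-inf` after the cast.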

This PR is akin to #17306 for PT. Thank you @LysandreJik and @ydshieh for pointing it out 🙏

@gante gante requested a review from sgugger January 19, 2023 16:10
gante (Contributor, Author) commented Jan 19, 2023

@sgugger this solution was discussed on Slack with the Flax team, hence no added Flax reviewers :)

sgugger (Collaborator) left a comment

LGTM! Thanks for the fix!

HuggingFaceDocBuilderDev commented Jan 19, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante gante merged commit cbaaa2f into huggingface:main Jan 19, 2023
@gante gante deleted the flax_opt_batch branch January 19, 2023 16:43
Development

Successfully merging this pull request may close these issues.

FlaxGPTNeoForCausalLM not working properly with fp16 when using left padding.

3 participants