-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[TIR] cast disparate floating point types for binary ops #8517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] cast disparate floating point types for binary ops #8517
Conversation
|
Thanks for the review @comaniac PTAL |
|
By the way, please don't merge this anytime soon. I want to get some input from some others like @jroesch or @junrushao1994 |
comaniac
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Leave it to @jroesch @junrushao1994
|
As long as we don't implicitly downcast I think this should be fine, can you test behavior of storing both an i32/u32 computation in i16/u16, and the other way around? |
This will fail since this PR only upcasts at most one arg. in binary operations only (e.g. addition). Do you want me to also upcast assignment for floating point types to be consistent? @jroesch |
I meant we should add negative tests so that if someone later comes along to modify the behavior we have clearly written down what should pass and what should fail. |
Oh ok. Done. |
|
PTAL @jroesch |
|
LGTM, just get it green and we are gtg |
|
Thanks @AndrewZhaoLuo @jroesch |
* handle upcasting case * test upcasting tests for tir * address comaniac comments * formatting * add negative tests * fix failing test now allow other things Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
* handle upcasting case * test upcasting tests for tir * address comaniac comments * formatting * add negative tests * fix failing test now allow other things Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
Right now if we in TIR add a floating point type with an integer type, the integer type will be implicitly cast to the floating point type.
E.g.
a: float32 + b: int32 ---> a: float32 + cast(float32, b: int32)This change does the same thing for floating point types. If we have two different floating point types e.g. fp16 and fp32, then when operating, the lower bit floating point type gets cast into the higher bit floating point type.
E.g.
a: float32 + b: float16 --> a: float32 + cast(float32, b: float16)This is of use since #8340 has an issue where some schedules which should support mixed precision types do not. This is due to binary ops like addition and multiplication not supporting mixing fp32 and fp16. Most of these errors can be fixed by inserting a cast into the schedule.
Rather than manually audit every schedule which might have this, this might be a preferable and reasonable solution.