Conversation
    __device__ __inline__ __e4m3 operator|(const __e4m3 x, const __e4m3 y) {
      // ...
      return val;
    }
Not exactly sure if this is reasonable. I don't know why memcpy is extensively used for these fp8 types.
Pinging @jjsjann123
This does look cleaner :)
I thought there may be some magic with memcpy.
!build
jacobhinkle left a comment:
LGTM. I wonder if we could handle all types (other than ComplexDouble) by bitcasting to an integer of the right size then doing bitwise or. For this we would need Int8 and Int16, but we would not need support in the runtime files for bitwise ops on floats then.
Yeah, I did think about it, but I found it's just easier to handle these low-precision types separately than to add new integer types, at least for now.
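The bitcast idea suggested above could be sketched roughly as follows. This is only an illustration under assumed names: `Half` is a hypothetical stand-in for a 16-bit float type like `__half`, not nvFuser's actual definition, and `std::memcpy` is used as the portable pre-C++20 way to express the bitcast.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for a 16-bit float type such as __half.
struct Half {
  std::uint16_t raw;
};

// Bitcast to a same-sized integer, apply the bitwise op, then
// bitcast back. The same pattern would need 8-bit integers for the
// fp8 types, which is why Int8/Int16 support comes up.
inline Half operator|(Half x, Half y) {
  std::uint16_t xi, yi;
  std::memcpy(&xi, &x, sizeof(x));
  std::memcpy(&yi, &y, sizeof(y));
  const std::uint16_t r = static_cast<std::uint16_t>(xi | yi);
  Half out;
  std::memcpy(&out, &r, sizeof(out));
  return out;
}
```

With this pattern, the runtime files would only need bitwise ops on integers, never directly on the float types.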
`cat` is translated to CUDA code with an if-then-else block. This is correct, but I believe it can be simplified to just

    out[idx] = input0[idx] + input1[idx] + ...

since all of the inputs are padded with zeros, so the result should be equivalent. Since `+` is not defined for some low-precision types, bitwise-or is used instead when addition is not available. On A100, this simplification yielded about a 5% perf improvement on a RoPE module.
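The equivalence above can be sketched in plain C++ with hypothetical names (a two-input `cat` along the last dimension; the actual kernels are generated CUDA):

```cpp
#include <cstddef>

// If-then-else lowering: each output index reads from the input
// that owns it.
float cat_if_then_else(const float* a, std::size_t na,
                       const float* b, std::size_t idx) {
  return (idx < na) ? a[idx] : b[idx - na];
}

// Simplified lowering: both inputs are assumed zero-padded to the
// full output length, so at each index exactly one operand is
// nonzero and a plain sum reproduces the selection.
float cat_sum(const float* a_padded, const float* b_padded,
              std::size_t idx) {
  return a_padded[idx] + b_padded[idx];
}
```

For types without `+`, the sum would be replaced by the bitwise-or operator, which is equivalent here because the padded operand's bits are all zero.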
I didn't add any specific test since I don't know if any new test would be beneficial. Half, bfloat16, and fp8 types are already tested by `ResizeTest.CatMemoryPromotionReducedFloating`.