Use linear-indexing broadcast kernel when possible#520
Conversation
This looks to be working well, so tagging people who ran into issues before: @ToucheSir and @chengchingwen. Note that this will still cause additional compilation, i.e., every time the size of any container involved in a broadcast changes, a new kernel gets compiled. I'm curious which workloads would trigger that (once in the steady-state application regime, of course).
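As a rough illustration of the compilation concern above (this is a hypothetical sketch, not the actual GPUArrays.jl internals): if the broadcast shape is baked into the kernel's type, e.g. via `Val`, every distinct shape triggers a fresh specialization.

```julia
# Stand-in for a size-specialized broadcast kernel: the shape is a type
# parameter, so each new shape compiles a new method specialization.
kernel_for(::Val{shape}) where {shape} = prod(shape)

kernel_for(Val((128, 128)))  # specializes (and compiles) for (128, 128)
kernel_for(Val((128, 129)))  # a different shape specializes again
```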
Force-pushed fcc80ce to 90fb573.
I had a look back through the CI failure on the Flux side. Apparently this call was the one that failed: But that's strange, because surely broadcasting + was already tested by GPUArrays + CUDA.jl? Anyhow, I doubt this will cause any problems for FluxML as long as elementwise broadcasting of binary ops still works across the board.
Looking back at the CUDA.jl CI logs, there seemed to be some issue with printing too, which is why I added a
I tried testing Transformers.jl, but that doesn't seem possible right now (see chengchingwen/Transformers.jl#153 and linked PRs in NeuralAttentionlib.jl).
One alternative would be to expose 1d/2d/3d indices and only generate 4 broadcast kernels. I'll experiment with that, as it would lead to far fewer kernels being compiled (though the fact that the bounds aren't fully statically known may come at a cost again). Given #451, the above would also mean that KA.jl would need to support 1d/2d/3d indices, so cc @vchuravy.
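The "only 4 broadcast kernels" idea above could look something like the following sketch: dispatch on dimensionality alone and pass the (dynamic) bounds at runtime. The kernel names here are illustrative placeholders, not actual GPUArrays.jl API.

```julia
# Pick one of a small fixed set of kernels based only on ndims, so changing
# the container *size* no longer triggers recompilation -- only changing
# the dimensionality would, and that is capped at 4 variants.
function select_kernel(dest::AbstractArray)
    nd = ndims(dest)
    nd == 1 ? :broadcast_kernel_1d :
    nd == 2 ? :broadcast_kernel_2d :
    nd == 3 ? :broadcast_kernel_3d :
              :broadcast_kernel_nd   # fallback: linear index + dynamic idiv
end
```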
... or, I should probably just confine this optimization to Metal.jl... |
This, unfortunately, happens a lot when doing sequence generation inference with transformer models. It might also happen during training but can be avoided with padding. |
OK, good to know. I have an alternative in JuliaGPU/Metal.jl#304, relying on hardware indices instead. That will only accelerate 2d and 3d broadcasts though, so it's a trade-off.
I think we might be able to port the algorithms used in libdivide to implement a new
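For context, the core trick in libdivide is to precompute a "magic" multiplier and shift for a known divisor, so each runtime division becomes a multiply plus a shift instead of an `idiv`. A minimal sketch of the round-up variant for 32-bit numerators (names are mine, not libdivide's; valid here for divisors up to 2^31, and real implementations use a high-half multiply rather than the 128-bit product below):

```julia
struct MagicDiv
    magic::UInt64
    shift::Int
end

function MagicDiv(d::UInt32)
    l = 32 - leading_zeros(d - one(d))            # ceil(log2(d)), computed exactly
    magic = cld(UInt64(1) << (32 + l), UInt64(d)) # round-up magic multiplier
    return MagicDiv(magic, 32 + l)
end

# n ÷ d without an integer division at runtime: multiply, then shift
magicdiv(n::UInt32, m::MagicDiv) = UInt32((UInt128(n) * m.magic) >> m.shift)
```

On GPUs where `idiv` is expensive, this trades one division for a wide multiply and a shift, both of which are cheap.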
I only noticed significant impact of the
Attempt to re-land #454, this time using a slightly nicer implementation.
It hasn't fundamentally changed though, so should run into the same issues. Let's do this carefully.
The motivation is also unchanged: on certain platforms, like Metal.jl, the integer divisions required to go from a linear hardware index to a cartesian one for indexing the input/output containers are extremely expensive. By using static iteration bounds, the compiler can replace the `idiv` with a series of bitshifts. This improves the performance of broadcast by 3-4x on those platforms.

cc @maxwindiff
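To illustrate the index computation in question (a sketch under my own naming, not the PR's actual kernel code): converting a linear hardware index into a cartesian one costs one integer division per dimension, and when the bounds are compile-time constants (e.g. specialized via `Val`), the compiler can strength-reduce those `div`s into shifts and multiplies.

```julia
# Dynamic bounds: the divisions below are real `idiv`s at runtime.
function lin2cart(i::Int, dims::NTuple{3,Int})
    d1, d2, _ = dims
    i0 = i - 1  # zero-based linear index
    (i0 % d1 + 1, (i0 ÷ d1) % d2 + 1, i0 ÷ (d1 * d2) + 1)
end

# Static bounds: specializing on `Val(dims)` makes `d1`/`d2` compile-time
# constants, enabling the shift-based lowering on platforms like Metal.
lin2cart(i::Int, ::Val{dims}) where {dims} = lin2cart(i, dims)
```

The dynamic version matches what `CartesianIndices` does when indexed linearly; the `Val` version is what ties the iteration bounds into the kernel's type.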