Add gradients for conv_bias_act, and a similar dense_bias_act #346

mcabbott wants to merge 15 commits into FluxML:master
Force-pushed from bcb3460 to 964dc16.
Is the original message up top still accurate? It looks like the implementation is there. What help is necessary to get this through?
My memory is that this basically worked, but the performance was disappointing due to JuliaLang/julia#43153. Writing back into the same …

Edit: ok I've updated things. I think the most honest benchmark looks like this, and shows a serious improvement:

```julia
julia> w, b = rand(Float32, 100, 100), rand(Float32, 100); x = rand(Float32, size(w)...);

julia> @btime gradient((w,x,b) -> sum(abs2, dense_bias_act(tanh, w, x, b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100)))  evals=1;
  min 44.792 μs, mean 79.901 μs (71 allocations, 198.37 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100)))  evals=1;
  min 114.583 μs, mean 158.989 μs (39 allocations, 275.25 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh_fast.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100)))  evals=1;
  min 40.125 μs, mean 75.140 μs (39 allocations, 275.25 KiB)
```

Would be worthwhile to benchmark on other computers. (This is M1 + Apple's BLAS.) And on GPUs. And …
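For the GPU side, something along these lines could be tried (an untested sketch; assumes CUDA.jl, Zygote, and this PR's `dense_bias_act` are loaded, and that it accepts `CuArray`s):

```julia
julia> using CUDA, Zygote, BenchmarkTools  # plus this PR's branch of NNlib

julia> w, b, x = cu(rand(Float32, 100, 100)), cu(rand(Float32, 100)), cu(rand(Float32, 100, 100));

julia> @btime CUDA.@sync gradient((w,x,b) -> sum(abs2, dense_bias_act(tanh, w, x, b)), $w, $x, $b);

julia> @btime CUDA.@sync gradient((w,x,b) -> sum(abs2, tanh.((w * x) .+ b)), $w, $x, $b);
```

`CUDA.@sync` waits for the kernels to finish, so the timings aren't just measuring the launch.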
Rebased at https://github.com/mcabbott/NNlib.jl/tree/bias_act_22 after squashing, but its own tests fail.
This aims to add gradient definitions for the existing `conv_bias_act`. That part is, however, very much WIP, and I don't recommend anyone try to read it just yet.

It also adds an analogous `dense_bias_act`, which is closer to done. What this gains you over `σ.(w*x .+ b)` is memory savings. Zygote will by default un-fuse the broadcast, allocating 3 arrays on the forward pass, but in fact we can often over-write the result of `w*x`, saving 2 copies. This should happen both on CPU and GPU. There is one more copy you could save on the reverse pass, bringing you to 1/2 the memory usage of before, but only if you were sure the pullback would be called only once. That isn't true for, say, `Zygote.jacobian`, and I don't think there's a way to know when it will be safe. So we save 1/3, not 1/2, when inside Zygote.
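To make the saving concrete, the forward-pass idea is roughly this (an illustrative sketch with a made-up name, not the code in this PR):

```julia
function dense_bias_act_sketch(σ::F, w, x, b) where {F}
    y = w * x        # the one unavoidable allocation
    y .= σ.(y .+ b)  # fuse bias + activation back into the same buffer
    return y         # vs. un-fused σ.(w*x .+ b) under Zygote: 1 array, not 3
end
```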
I say "often" because over-writing `w*x` only works when the gradient of `σ` can be written in terms of its output, without saving its input. That's true for `tanh` and `relu` and some others, which are ~~explicitly whitelisted here as `INPLACE_ACTS`~~ now handled using JuliaDiff/ChainRulesCore.jl#453. Surely a more extensible method for that could be invented.
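For instance (a sketch of the distinction, not the PR's code):

```julia
# tanh and relu qualify: their derivatives are recoverable from the
# output y = σ(x) alone, so the input may be overwritten.
dtanh_from_y(y) = 1 - y^2                         # since tanh'(x) = 1 - tanh(x)^2
drelu_from_y(y) = ifelse(y > 0, one(y), zero(y))

# Counterexample: sin'(x) = cos(x) = ±sqrt(1 - sin(x)^2), which the
# output determines only up to sign, so sin would not qualify.
```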
This was written before seeing FluxML/NNlibCPU.jl#1. But they may work well together -- for instance the function `dense!` there could (after we adjust signatures a little) simply overload a function here, providing a fast path when that package is loaded. Likewise it can overload `conv_bias_act!` to run a fused activation-and-convolution on the CPU, a bit like the existing NNlibCUDA routine. (From a first glance it looks like `dense!` has a trait for deciding which functions are in-place-safe, which is good.) Again, not fully baked, but opened now to start discussing.
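Concretely, the hand-off might look something like this (signatures invented for illustration; the real ones are what we'd need to adjust):

```julia
using LinearAlgebra: mul!

# A generic in-place method NNlib could own as the fallback:
dense_bias_act!(y, σ::F, w, x, b) where {F} = y .= σ.(mul!(y, w, x) .+ b)

# ...which NNlibCPU could then overload for the types its kernels handle:
# NNlib.dense_bias_act!(y, σ, w::Matrix{Float32}, x::Matrix{Float32}, b) = ...
```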