rule for `maximum(f, xs)` #490
The broadcasted one uses dual numbers, which is much quicker; note BTW that there is no chunk mode in play here. I'm not so sure why the complete reduction is slower than broadcasting here, but it's much closer, and uses 3x less memory. Diffractor, BTW, does not see this rule. It does see #480, but its broadcast times are variable.
This has been much simplified. For the case of a complete reduction only:

```julia
julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30)));  # this PR + Zygote + Julia 1.8
  min 8.625 μs, mean 10.906 μs (52 allocations, 8.92 KiB. GC mean 13.94%)

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
  min 10.041 μs, mean 16.087 μs (49 allocations, 36.88 KiB. GC mean 20.75%)

julia> @btime gradient(x -> sum(maximum(log∘exp, x)), $(rand(30,30)));  # with a more expensive function
  min 20.208 μs, mean 22.335 μs (116 allocations, 10.88 KiB. GC mean 5.22%)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x))), $(rand(30,30)));
  min 19.291 μs, mean 25.757 μs (49 allocations, 36.88 KiB. GC mean 13.03%)

julia> @btime maximum(log∘exp, $(rand(30,30)));
  min 8.958 μs, mean 9.128 μs (0 allocations)
```
Status here is as in the first message above. Perhaps the broadcast path can be easily tested using JuliaDiff/ChainRulesTestUtils.jl#243 once that's available.
mzgubic left a comment:
A few questions, generally looks good. Do you plan to extend the tests?
test/rulesets/Base/mapreduce.jl (Outdated)

```julia
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(; dims = 2), check_inferred=false)
```
I thought these needed JuliaDiff/ChainRulesTestUtils.jl#243: with `dims` it always calls broadcast.
Yep, they do need JuliaDiff/ChainRulesTestUtils.jl#243 (now merged), but also JuliaDiff/FiniteDifferences.jl#203 to get around `to_vec`-ing `InplaceableThunk`s correctly (tested locally).
But where do InplaceableThunks come from? This path of this rule doesn't make them.
I do still get an error with only the CRTU update:

```julia
julia> test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
test_rrule: maximum on typeof(sqrt),Matrix{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/fCvaU/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch("second dimension of A, 4, does not match length of x, 7")
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
[2] mul!
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
[3] mul!
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
[4] *(tA::Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/R6uao/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#45"{ChainRulesTestUtils.var"#call#41"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{typeof(broadcast), typeof(sqrt), Matrix{Float64}}, Tuple{Bool, Bool, Bool}}, ȳ::InplaceableThunk{Thunk{ChainRules.var"#1316#1319"{Matrix{Float64}, Int64, Matrix{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, Matrix{CartesianIndex{2}}}}, ChainRules.var"#1317#1320"{Matrix{Float64}, Int64, Matrix{CartesianIndex{2}}}}, x::Matrix{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/R6uao/src/grad.jl:73
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/fCvaU/src/finite_difference_calls.jl:51
[8] f_pb
@ ~/.julia/packages/ChainRulesTestUtils/fCvaU/src/rule_config.jl:40 [inlined]
[9] (::ChainRules.var"#minormax_f_back2#2098"{ChainRules.var"#maximum_pullback#1326"{ChainRules.var"#findmax_pullback#1318"{Int64
```
Solved by the to_vec PR, as you said.
Can this thing give less cryptic errors than "DimensionMismatch" when it goes wrong?
Yeah, I agree with you in general: JuliaDiff/ChainRulesTestUtils.jl#244
Here, though, this is coming from `rrule_via_ad` using `_make_j′vp_call` rather than the usual place 😂
Solving JuliaDiff/ChainRulesTestUtils.jl#213 would be a big QoL improvement indeed. It's on my list
JuliaDiff/FiniteDifferences.jl#203 is now merged, so I think we can update the tests
Great!
This one is weird locally, but on 1.6 it seems to work (or will, once changed to `≈ [10 0 0; 0 -20 0]`):

```julia
julia> y2, bk2 = rrule(CFG, minimum, abs, [1 2 3; -5 -4 -4], dims = 2);

julia> @test y2 == hcat([1, 4])
Test Passed
  Expression: y2 == hcat([1, 4])
   Evaluated: [1; 4;;] == [1; 4;;]

julia> bk2(hcat([10, 20]))
(NoTangent(), NoTangent(), NoTangent())
```
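For reference, the expected `dims = 2` gradient can be checked by hand in plain Julia, with no AD at all (a sketch, not the rule's implementation): `minimum(abs, x, dims = 2)` picks `|x| = 1` at `(1,1)` and first hits `|x| = 4` at `(2,2)`, and `d|x|/dx = sign(x)`, so seeding with `[10, 20]` should scatter to exactly those slots:

```julia
# Hand check of the expected gradient, assuming ties go to the first minimum found.
x = [1 2 3; -5 -4 -4]
y, inds = findmin(abs.(x), dims = 2)    # row-wise minima and their CartesianIndices
dx = zeros(Int, size(x))
for (k, dy) in zip(eachindex(inds), [10, 20])
    i = inds[k]
    dx[i] = dy * sign(x[i])             # chain rule through abs at the winning entry
end
@show y    # 2×1 matrix with values 1 and 4
@show dx   # [10 0 0; 0 -20 0]
```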
Commits:
- save less stuff in sum(f, xs) rule
- probably destroyed in the rebase
- re-organise
- change to use BitArray
- add a few tests
- Revert "save less stuff in sum(f, xs) rule" (this reverts commit c8034da)
- tidy, add cumsum trick
- tests for multiple maxima
- tweaks
This uses the `RuleConfig{>:HasReverseMode}` story to call back into AD to write a rule for `maximum(f, xs)`. It's much simplified from the first attempt: it uses `i = findmax(f, xs)`, and then `rrule_via_ad(f, xs[i])`.

Fast case, before & after: before this PR, `gradient(x -> sum(maximum(sqrt, x, dims=1)), (rand(30,30)))` gives an error with Zygote. After, it is the same speed as broadcasting. What doesn't seem easy now is testing the broadcast path.
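A minimal sketch of that strategy, in hypothetical code (not the PR's implementation): run `findmax(f, xs)` once, then differentiate `f` only at the winning element. A central difference stands in here for the single `rrule_via_ad(config, f, xs[i])` call.

```julia
# Sketch only: the gradient of maximum(f, xs) is zero everywhere except at the
# argmax, where it is f'(xs[i]) * dy.
function maximum_fx_sketch(f, xs::AbstractArray)
    y, i = findmax(f, xs)                # needs Julia >= 1.7 for findmax(f, itr)
    function pullback(dy)
        x = float(xs[i])
        h = cbrt(eps(x))
        dfdx = (f(x + h) - f(x - h)) / (2h)   # stand-in for one reverse-mode call
        dxs = zeros(typeof(dfdx * dy), size(xs))
        dxs[i] = dfdx * dy               # scatter into the argmax slot
        return dxs
    end
    return y, pullback
end

y, bk = maximum_fx_sketch(sqrt, [1.0 4.0; 9.0 2.0])
# y == 3.0, and bk(1.0) is zero except ≈ 1/6 at the position of 9.0
```

Note that `f` is evaluated on every element by `findmax`, but only differentiated once; that is the source of the `N^2 + 1` call count discussed below.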
First attempt

However, it only needs one such call, rather than one for every element. That means it ends up calling `f` say `N^2 + 1` times for a matrix (or `N^2 + N` with `dims`). This is much more efficient than calling it via AD all `N^2` times, saving the pullbacks somewhere, and calling just one. Not always faster than Zygote's current broadcasting (which uses ForwardDiff), but much less memory.

If this is OK, then perhaps the `sum(f, x)` rule from #441 should also consider calling `f` more times. There's a commit here doing that, which cuts the memory use by quite a bit. Perhaps there are functions `f` for which calling twice would be slower? Perhaps writing `sum(f, x)` vs. `sum(f.(x))` is how you emphasise that you care more about memory? (It may make sense to remove this & discuss `sum` in another thread.) [Now removed here.]

All WIP, needs more careful testing, etc.
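To make the `sum(f, x)` trade-off concrete, here is a hypothetical sketch (not #441's code) of a rule that stores nothing per element on the forward pass and instead re-evaluates `f` during the pullback, again with a central difference standing in for per-element AD:

```julia
# Sketch only: trades extra calls to f for O(1) extra storage on the forward pass,
# the memory/time trade-off discussed above.
function sum_f_sketch(f, xs::AbstractArray)
    y = sum(f, xs)                       # forward pass keeps no per-element state
    function pullback(dy)
        map(xs) do x                     # two extra calls to f per element
            h = cbrt(eps(float(x)))
            dy * (f(x + h) - f(x - h)) / (2h)
        end
    end
    return y, pullback
end

y, bk = sum_f_sketch(sqrt, [1.0, 4.0, 9.0])
# y == 6.0; bk(1.0) ≈ [0.5, 0.25, 1/6], i.e. sqrt'(x) = 1/(2√x) at each element
```

Whether this is a win depends on how expensive `f` is relative to the memory saved, which is exactly the question raised above.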