Add rrules for extrema, findmax, maximum #480
Codecov Report
@@            Coverage Diff             @@
##           master     #480      +/-   ##
==========================================
- Coverage   98.38%   98.05%   -0.33%
==========================================
  Files          21       22       +1
  Lines        2287     2414     +127
==========================================
+ Hits         2250     2367     +117
- Misses         37       47      +10
test/rulesets/Base/array.jl
Outdated
@testset "$findm" for findm in [findmax, findmin]
    @test_skip test_rrule(findm, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error?
    @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), 999), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf?
end
@test rrule(findmax, [1,2,33])[1] == (33, 3)
@test rrule(findmin, [11,22,33])[1] == (11, 1)

@test [0,0,1] == @inferred unthunk(rrule(findmax, [1,2,3])[2]((1.0, nothing))[2])
@test [1,0,0] == @inferred unthunk(rrule(findmin, [1,2,3])[2]((1.0, nothing))[2])
There was a problem hiding this comment.
Not sure why test_rrule fails here, but explicit tests work. The error is:
julia> test_rrule(findm, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false)
test_rrule: findmax on Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:191
Got exception outside of a @test
DimensionMismatch("second dimension of A, 2, does not match length of x, 1")
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:477
[2] mul!
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:87 [inlined]
[3] mul!
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:255 [inlined]
[4] *(tA::Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:80
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/aqPCI/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Tuple{Float64, NoTangent}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/aqPCI/src/grad.jl:73
[7] _make_j′vp_call(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Tuple{Float64, NoTangent}, xs::Tuple{typeof(findmax), Vector{Float64}}, ignores::Tuple{Bool, Bool})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/finite_difference_calls.jl:51
[8] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:222 [inlined]
[9] macro expansion
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/Test/src/Test.jl:1282 [inlined]
[10] test_rrule(config::ChainRulesTestUtils.ADviaRuleConfig, f::typeof(findmax), args::Vector{Float64}; output_tangent::Tuple{Float64, NoTangent}, check_thunked_output_tangent::Bool, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:194
There was a problem hiding this comment.
In the revised code, this one is fixed, but errors persist for the dims=1 etc. cases.
I am on leave most of next week.
src/rulesets/Base/array.jl
Outdated
@eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:)
    y, ind = $findm(x; dims=dims)
    project = ProjectTo(x)
    function $findm_pullback((dy, _))
There was a problem hiding this comment.
Should this take a Tangent instead? I think we can still dispatch on dy being an AbstractZero
There was a problem hiding this comment.
Maybe? I am a bit confused about Tangent. I was trying things out with Zygote and they appear to work, but perhaps this would still work if the signature was findm_pullback(::Tangent).
There was a problem hiding this comment.
Ah, never mind, I thought the destructuring places a constraint. It's fine this way
There was a problem hiding this comment.
Added a comment & checked
There was a problem hiding this comment.
The method below $findm_pullback(::Tuple{AbstractZero, Any}) will however not accept a Tangent.
Should it be Tangent{<:Any, <: Tuple{AbstractZero, Any}}? Or just dy isa AbstractZero && return (NoTangent(), NoTangent())?
There was a problem hiding this comment.
(Now changed to a branch, seems simplest.)
There was a problem hiding this comment.
The method below $findm_pullback(::Tuple{AbstractZero, Any}) will however not accept a Tangent.
this is fixed now, right?
There was a problem hiding this comment.
Yes, I killed that method, and just check the type after destructuring.
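For readers following this thread, here is a minimal sketch of the "branch after destructuring" approach, written with Base only. `ZeroLike` and `NoTangentLike` are hypothetical stand-ins for ChainRulesCore's `AbstractZero` and `NoTangent`, and `findmax_with_pullback` is an illustrative name, not the PR's `rrule`. Thunks and `ProjectTo` are left out, so this is not the merged implementation.

```julia
# Hypothetical stand-ins for ChainRulesCore's AbstractZero / NoTangent,
# defined here only so the sketch runs without any packages.
abstract type ZeroLike end
struct NoTangentLike <: ZeroLike end

function findmax_with_pullback(x::AbstractArray{<:Number}; dims=:)
    y, ind = findmax(x; dims=dims)
    # The argument is destructured, so any 2-element iterable is accepted.
    function findmax_pullback((dy, _))
        # Branch on the type after destructuring instead of adding a method.
        dy isa ZeroLike && return (NoTangentLike(), NoTangentLike())
        dx = fill!(similar(x, eltype(dy)), false)  # zeros, eltype taken from dy
        view(dx, ind) .= dy  # 0-dim view for scalar dy, array view when dims is given
        return (NoTangentLike(), dx)
    end
    return (y, ind), findmax_pullback
end

(y, i), back = findmax_with_pullback([1, 2, 33])
back((1.0, nothing))[2]             # → [0.0, 0.0, 1.0]
back((NoTangentLike(), nothing))    # short-circuits, no array is allocated
```

Because the check happens inside the body, the same pullback handles both a plain tuple and any tuple-like tangent without needing a second method signature.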
src/rulesets/Base/array.jl
Outdated
function $findm_pullback((dy, _))
    x_thunk = @thunk begin
        dx = fill!(similar(x, eltype(dy)), false)
        view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
        project(dx)
    end
    x_ithunk = InplaceableThunk(x_thunk) do dx
        view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0
        dx
    end
    return (NoTangent(), x_ithunk)
end
There was a problem hiding this comment.
could we close over size of x only here?
alternatively, I wonder whether we could reuse the rrule for getindex?
i.e. something like
- function $findm_pullback((dy, _))
-     x_thunk = @thunk begin
-         dx = fill!(similar(x, eltype(dy)), false)
-         view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
-         project(dx)
-     end
-     x_ithunk = InplaceableThunk(x_thunk) do dx
-         view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0
-         dx
-     end
-     return (NoTangent(), x_ithunk)
- end
+ function $findm_pullback((dy, _))
+     _, getindex_back = rrule(getindex, x, ind)
+     return getindex_back(dy)[1:2]
+ end
There was a problem hiding this comment.
could we close over size of x only here?
There is similar(typeof(x), size(x)) but no similar(typeof(x), T, size(x)). Of course very often ProjectTo is going to ensure that the eltype should really be the same as x's, but not quite always. It's a bit awkward.
getindex has @thunk(getindex_add!(zero(x))) which seems worse -- it is not always going to be mutable, and it won't handle structure well, e.g. zero(Diagonal([1,2,3]))[1,2] = 9.
Agree that getting it right in one place makes some sense. Zygote now has a special struct for scalar getindex, especially because using that repeatedly in a loop seems common. That does not seem so common for maximum, which could mean the weirdness doesn't pay for itself? Or maybe InplaceableStuff will make that obsolete anyway.
There was a problem hiding this comment.
I am not sure I understand the getindex comment. Isn't similar also going to make a Diagonal, just like zero does?
There was a problem hiding this comment.
Oh yea, both mess that one up, you need similar(Diagonal(rand(3)), Int, (3,3)) to get something sure to be writeable. It's zero(SA[1,2,3]) which is worse than similar.
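The Diagonal point can be checked with the LinearAlgebra standard library alone. The sketch below is only an illustration of the behaviour described in this thread; `write_fails` is a hypothetical helper name, and the concrete array type returned by the three-argument `similar` is deliberately not asserted, since it is not stated here.

```julia
using LinearAlgebra

D = Diagonal([1, 2, 3])

# Hypothetical helper: true if writing a nonzero off-diagonal entry throws.
function write_fails(A)
    try
        A[1, 2] = 9
        return false
    catch
        return true
    end
end

write_fails(zero(D))      # zero keeps the Diagonal structure
write_fails(similar(D))   # so does one-argument similar

# Passing an eltype and explicit dims asks for a generic container instead.
dx = fill!(similar(D, Float64, size(D)), false)
dx[1, 2] = 9.0            # accepted
```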
There was a problem hiding this comment.
Or maybe InplaceableStuff will make that obsolete anyway.
Is that your PR which makes InplaceableThunk take the third argument? I like that idea
There was a problem hiding this comment.
Oh maybe that too. But I meant that Zygote.OneElement is part of trying to speed up scalar indexing in a loop. The next step is FluxML/Zygote.jl#981. But this is a bit of a Zygote-style hack, and the eventual version will involve ChainRules's in-place stuff. Then a tight loop might generate a million thunks, instead of a million OneElement "arrays".
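To make the "OneElement array" idea concrete, here is a rough sketch of a lazy array with a single nonzero entry. `OneHotLike` is a hypothetical type written for illustration; it is not Zygote's actual `OneElement` definition and makes no claim about its fields or performance.

```julia
# Hypothetical lazy array: stores one value and its position, nothing else.
struct OneHotLike{T,N} <: AbstractArray{T,N}
    val::T
    ind::NTuple{N,Int}
    dims::NTuple{N,Int}
end

Base.size(A::OneHotLike) = A.dims

# Every entry is zero except the stored one.
Base.getindex(A::OneHotLike{T,N}, i::Vararg{Int,N}) where {T,N} =
    i == A.ind ? A.val : zero(T)

g = OneHotLike(5.0, (2, 2), (2, 2))   # gradient of x[2, 2] for a 2×2 x
acc = zeros(2, 2)
acc .= acc .+ g                       # accumulate without materialising a dense gradient
```

Each such object is cheap to create, which is the trade-off discussed above: a loop of scalar indexing produces many small lazy gradients rather than many dense zero-filled arrays.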
Thanks, will update to use new CRCore. Check that second derivatives aren't completely wrong (with Zygote's rule disabled):
test/rulesets/Base/array.jl
Outdated
@test_skip test_frule(findmin, rand(3,4)) # StackOverflowError, CartesianIndex{2}(index::Tuple{Float64, Float64}) (repeats 79984 times) & TypeError: in new, expected Tuple{Int64, Int64}, got a value of type Tuple{Float64, Float64}
@test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent()))
@test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,))

# Reverse
test_rrule(findmin, rand(10), output_tangent = (rand(), false))
test_rrule(findmax, rand(10), output_tangent = (rand(), false))
@test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
@test_skip test_rrule(findmin, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # DimensionMismatch from FiniteDifferences
@test_skip test_rrule(findmax, rand(5,3), output_tangent = (rand(), false), check_inferred=false) # DimensionMismatch from FiniteDifferences
There was a problem hiding this comment.
Was still a bit unhappy about these tests.
There was a problem hiding this comment.
I think this would solve it: JuliaDiff/FiniteDifferences.jl#188
There was a problem hiding this comment.
Any thoughts on this? Maybe should merge without these tests.
There was a problem hiding this comment.
I would suggest keeping it as it is and adding a comment that a lot of the dimension mismatches would be solved by fixing JuliaDiff/FiniteDifferences.jl#188
There was a problem hiding this comment.
FWIW the current error here is:
julia> test_frule(findmin, rand(3,4))
test_frule: findmin on Matrix{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/73Y9Q/src/testers.jl:118
Got exception outside of a @test
iteration is deliberately unsupported for CartesianIndex. Use `I` rather than `I...`, or use `Tuple(I)...`
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] iterate(#unused#::CartesianIndex{2})
@ Base.IteratorsMD ./multidimensional.jl:167
[3] copyto!(dest::Vector{Int64}, src::CartesianIndex{2})
@ Base ./abstractarray.jl:901
[4] _collect(cont::UnitRange{Int64}, itr::CartesianIndex{2}, #unused#::Base.HasEltype, isz::Base.HasLength)
@ Base ./array.jl:715
[5] collect(itr::CartesianIndex{2})
@ Base ./array.jl:709
[6] test_approx(actual::CartesianIndex{2}, expected::CartesianIndex{2}, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/73Y9Q/src/check_result.jl:141
test/rulesets/Base/array.jl
Outdated
# Reverse with dims:
@test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2])
@test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2])
@test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf?
There was a problem hiding this comment.
These could be solved by FiniteDifferences knowing that CartesianIndex is not perturbable: JuliaDiff/FiniteDifferences.jl#196
Here's a PR to fix it: JuliaDiff/FiniteDifferences.jl#197
There was a problem hiding this comment.
that's been merged, it needs FiniteDifferences 0.12.20
There was a problem hiding this comment.
That now passes, thanks!
Testsets pass locally, will merge when (if!) CI agrees. Besides tests, upgraded today to allow arrays of arrays. There are a few tests skipped, I think due to weird FiniteDifferences errors. But I think the rules work, and e.g. the
Aims to address FluxML/Zygote.jl#1034 by widening the type of the array of zeros it writes into. And, while there, fixes some related functions:
Still some possible bugs in frules, or their tests?