diff --git a/Project.toml b/Project.toml index 1274201c2..b36f894cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.13.0" +version = "1.14.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,10 +11,10 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "1.10" +ChainRulesCore = "1.11" ChainRulesTestUtils = "1" Compat = "3.35" -FiniteDifferences = "0.12.8" +FiniteDifferences = "0.12.20" JuliaInterpreter = "0.8" RealDot = "0.1" StaticArrays = "1.2" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 184a833f0..63ce70ae0 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -342,3 +342,165 @@ function rrule(::typeof(fill), x::Any, dims...) fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...) return fill(x, dims...), fill_pullback end + +##### +##### `findmax`, `maximum`, etc. +##### + +for findm in (:findmin, :findmax) + findm_pullback = Symbol(findm, :_pullback) + + @eval function frule((_, xdot), ::typeof($findm), x; dims=:) + y, ind = $findm(x; dims=dims) + return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent()) + end + + @eval function rrule(::typeof($findm), x::AbstractArray; dims=:) + y, ind = $findm(x; dims=dims) + project = ProjectTo(x) + # This pullback is a lot like the one for getindex. Ideally they would probably be combined? + function $findm_pullback((dy, _)) # this accepts e.g. 
Tangent{Tuple{Float64, Int64}}(4.0, nothing) + dy isa AbstractZero && return (NoTangent(), NoTangent()) + x_thunk = @thunk project(_zerolike_writeat(x, unthunk(dy), dims, ind)) + x_ithunk = InplaceableThunk(x_thunk) do dx + if dims isa Colon + view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy)) + else + view(dx, ind) .= view(dx, ind) .+ unthunk(dy) # this could be .+=, but not on Julia 1.0 + end + dx + end + return (NoTangent(), x_ithunk) + end + return (y, ind), $findm_pullback + end +end + +# This function is roughly `setindex!(zero(x), dy, inds...)`: + +function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...) + # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't + # allow `eltype(dy)`, nor does it work for many structured matrices. + dx = fill!(similar(x, eltype(dy), axes(x)), 0) + view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray + dx +end +function _zerolike_writeat(x::AbstractArray, dy, dims, inds...) + # Since we have `x`, we can also handle arrays of arrays. + dx = map(zero, x) + if dims isa Colon + view(dx, inds...) .= Ref(dy) + else + view(dx, inds...) .= dy + end + dx +end + +# Allow for second derivatives, by writing rules for `_zerolike_writeat`; +# these rules are the reason it takes a `dims` argument. + +function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...) + return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dydot, dims, inds...) +end + +function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...) + z = _zerolike_writeat(x, dy, dims, inds...) + function _zerolike_writeat_pullback(dz) + dx = sum(view(unthunk(dz), inds...); dims=dims) + nots = map(_ -> NoTangent(), inds) + return (NoTangent(), NoTangent(), dx, NoTangent(), nots...) 
+ end + return z, _zerolike_writeat_pullback +end + +# These rules for `maximum` pick the same subgradient as `findmax`: + +function frule((_, xdot), ::typeof(maximum), x; dims=:) + y, ind = findmax(x; dims=dims) + return y, xdot[ind] +end + +function rrule(::typeof(maximum), x::AbstractArray; dims=:) + (y, _), back = rrule(findmax, x; dims=dims) + maximum_pullback(dy) = back((dy, nothing)) + return y, maximum_pullback +end + +function frule((_, xdot), ::typeof(minimum), x; dims=:) + y, ind = findmin(x; dims=dims) + return y, xdot[ind] +end + +function rrule(::typeof(minimum), x::AbstractArray; dims=:) + (y, _), back = rrule(findmin, x; dims=dims) + minimum_pullback(dy) = back((dy, nothing)) + return y, minimum_pullback +end + +##### +##### `extrema` +##### + +function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) + if dims isa Colon + return _extrema_colon(x) + else + return _extrema_dims(x, dims) + end +end + +function _extrema_colon(x) + ylo, ilo = findmin(x) + yhi, ihi = findmax(x) + project = ProjectTo(x) + function extrema_pullback((dylo, dyhi)) # accepts Tangent + if (dylo, dyhi) isa Tuple{AbstractZero, AbstractZero} + return (NoTangent(), NoTangent()) + end + # One argument may be AbstractZero here. Use promote_op because + # promote_type allows for * as well as +, hence gives Any. 
+ T = Base.promote_op(+, typeof(dylo), typeof(dyhi)) + x_nothunk = let + # x_thunk = @thunk begin # this doesn't infer + dx = fill!(similar(x, T, axes(x)), false) + view(dx, ilo) .= dylo + view(dx, ihi) .= view(dx, ihi) .+ dyhi + project(dx) + end + # x_ithunk = InplaceableThunk(x_thunk) do dx + # view(dx, ilo) .= view(dx, ilo) .+ dylo + # view(dx, ihi) .= view(dx, ihi) .+ dyhi + # dx + # end + return (NoTangent(), x_nothunk) + end + return (ylo, yhi), extrema_pullback +end + +function _extrema_dims(x, dims) + ylo, ilo = findmin(x; dims=dims) + yhi, ihi = findmax(x; dims=dims) + y = similar(ylo, Tuple{eltype(ylo), eltype(yhi)}) + map!(tuple, y, ylo, yhi) # this is a GPU-friendly version of collect(zip(ylo, yhi)) + project = ProjectTo(x) + function extrema_pullback_dims(dy_raw) + dy = unthunk(dy_raw) + @assert dy isa AbstractArray{<:Tuple{Any,Any}} + # Can we actually get Array{Tuple{Float64,ZeroTangent}} here? Not sure. + T = Base.promote_op(+, eltype(dy).parameters...) + x_nothunk = let + # x_thunk = @thunk begin # this doesn't infer + dx = fill!(similar(x, T, axes(x)), false) + view(dx, ilo) .= first.(dy) + view(dx, ihi) .= view(dx, ihi) .+ last.(dy) + project(dx) + end + # x_ithunk = InplaceableThunk(x_thunk) do dx + # view(dx, ilo) .= first.(dy) + # view(dx, ihi) .= view(dx, ihi) .+ last.(dy) + # dx + # end + return (NoTangent(), x_nothunk) + end + return y, extrema_pullback_dims +end diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 0412ac451..dad5533f7 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -96,6 +96,8 @@ @non_differentiable all(::Any, ::Any) @non_differentiable any(::Any) @non_differentiable any(::Any, ::Any) +@non_differentiable argmax(::Any) +@non_differentiable argmin(::Any) @non_differentiable ascii(::AbstractString) @non_differentiable axes(::Any) @non_differentiable axes(::Any, ::Any) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index f601a0bb8..ff3e93931 
100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -198,3 +198,79 @@ end test_rrule(fill, 55 + 0.5im, 5) test_rrule(fill, 3.3, (3, 3, 3)) end + +@testset "findmin & findmax" begin + # Forward + test_frule(findmin, rand(10)) + test_frule(findmax, rand(10)) + @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4))) isa Tuple{Tuple{Float64, CartesianIndex}, Tangent} + @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4), dims=1)) isa Tuple{Tuple{Matrix, Matrix}, Tangent} + @test_skip test_frule(findmin, rand(3,4)) # error from test_approx(actual::CartesianIndex{2}, expected::CartesianIndex{2}) + @test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent())) + @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,)) + # These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 + + # Reverse + test_rrule(findmin, rand(10), output_tangent = (rand(), false)) + test_rrule(findmax, rand(10), output_tangent = (rand(), false)) + test_rrule(findmin, rand(5,3)) + test_rrule(findmax, rand(5,3)) + @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) + @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2]) + + # Reverse with dims: + @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2]) + @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) + test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent())) + test_rrule(findmin, rand(3,4), fkwargs=(dims=2,)) + + # Second derivatives + test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) + test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) + @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) #
MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9) + y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)]) + @test y == [0 0; 5 5] + @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent()) +end + +@testset "$imum" for imum in [maximum, minimum] + # Forward + test_frule(imum, rand(10)) + test_frule(imum, rand(3,4)) + test_frule(imum, rand(3,4), fkwargs=(dims=1,)) + test_frule(imum, [rand(2) for _ in 1:3]) + test_frule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(dims=1,)) + + # Reverse + test_rrule(imum, rand(10)) + test_rrule(imum, rand(3,4)) + test_rrule(imum, rand(3,4), fkwargs=(dims=1,)) + test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),)) + + # Arrays of arrays + test_rrule(imum, [rand(2) for _ in 1:3]; check_inferred=false) + test_rrule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(dims=1,), check_inferred=false) + + # Case which attains max twice -- can't use FiniteDifferences for this + res = imum == maximum ? 
[0,1,0,0,0,0] : [1,0,0,0,0,0] + @test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2]) + + # Structured matrix -- NB the minimum is a structural zero here + @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal + @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular{Float64} + @test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool) +end + +@testset "extrema" begin + test_rrule(extrema, rand(10), output_tangent = (rand(), rand())) + test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = collect(zip(rand(1,4), rand(1,4)))) + # Case where both extrema are the same index, to check accumulation: + test_rrule(extrema, rand(1), output_tangent = (rand(), rand())) + test_rrule(extrema, rand(1,1), fkwargs=(dims=2,), output_tangent = hcat((rand(), rand()))) + test_rrule(extrema, rand(3,1), fkwargs=(dims=2,), output_tangent = collect(zip(rand(3,1), rand(3,1)))) + # Double-check the forward pass + A = randn(3,4,5) + @test extrema(A, dims=(1,3)) == rrule(extrema, A, dims=(1,3))[1] + B = hcat(A[:,:,1], A[:,:,1]) + @test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1] +end