From fc6fad91a8d75b05a377852263d2d2e0117c3f45 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 06:41:06 -0400 Subject: [PATCH 01/30] rules for extrema, findmax, maximum --- src/rulesets/Base/array.jl | 73 +++++++++++++++++++++++++++++++++++++ test/rulesets/Base/array.jl | 12 ++++++ 2 files changed, 85 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 184a833f0..78545b6f1 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -342,3 +342,76 @@ function rrule(::typeof(fill), x::Any, dims...) fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...) return fill(x, dims...), fill_pullback end + +##### +##### `extrema`, `findmax`, `maximum`, etc. +##### + +function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) + yplus, iplus = findmax(x; dims=dims) + yminus, imminus = findmin(x; dims=dims) + project = ProjectTo(x) + function extrema_pullback((dyplus, dyminus)) + x_thunk = @thunk begin + dx = fill!(similar(x, eltype(dy)), false) + view(dx, iplus) .= dyplus + view(dx, iminus) .+= dyminus + project(dx) + end + x_ithunk = InplaceableThunk(x_thunk) do dx + view(dx, iplus) .+= dyplus + view(dx, iminus) .+= dyminus + dx + end + return (NoTangent(), x_ithunk) + end + return (yplus, yminus), extrema_pullback +end + +function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) + y, ind = findmax(x; dims=dims) + project = ProjectTo(x) + function findmax_pullback((dy, _)) + x_thunk = @thunk begin + dx = fill!(similar(x, eltype(dy)), false) + view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array cases + project(dx) + end + x_ithunk = InplaceableThunk(x_thunk) do dx + view(dx, ind) .+= dy + dx + end + return (NoTangent(), x_ithunk) + end + return (y, ind), findmax_pullback +end + +function rrule(::typeof(findmin), x::AbstractArray{<:Number}; dims=:) + y, ind = findmin(x; dims=dims) + project = ProjectTo(x) + function findmin_pullback((dy, _)) + x_thunk = @thunk begin + dx = fill!(similar(x, eltype(dy)), false) + view(dx, ind) .= dy + project(dx) + end + x_ithunk = InplaceableThunk(x_thunk) do dx + view(dx, ind) .+= dy + dx + end + return (NoTangent(), x_ithunk) + end + return (y, ind), findmin_pullback +end + +function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) + (y, _), back = rrule(findmax, x; dims=dims) + maximum_pullback(dy) = back((dy, nothing)) + return y, maximum_pullback +end + +function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) + (y, _), back = rrule(findmin, x; dims=dims) + minimum_pullback(dy) = back((dy, nothing)) + return y, minimum_pullback +end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index f601a0bb8..4fccadb84 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -198,3 +198,15 @@ end test_rrule(fill, 55 + 0.5im, 5) test_rrule(fill, 3.3, (3, 3, 3)) end + +@testset "extrema" begin + @testset "$f" for f in [maximum, minimum] + test_rrule(f, rand(10)) + test_rrule(f, rand(3,4)) + test_rrule(f, rand(3,4), fkwargs=(dims=1,)) + test_rrule(f, rand(3,4,5), fkwargs=(dims=(1,3),)) + test_rrule(f, rand(1)) # both extrema are the same index + @test_skip test_rrule(f, Float64[1,2,-1,-2,0,2,-2]) # attains max twice -- finite diff picks another subgradient + end +end + From 0e01222b1bf635cd90d4218704c96e7d1dc013be Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 23:22:59 -0400 Subject: [PATCH 02/30] fixup extrema --- src/rulesets/Base/array.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 78545b6f1..5356152e6 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -348,24 +348,26 @@ end ##### function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) - yplus, iplus = findmax(x; dims=dims) - yminus, imminus = findmin(x; dims=dims) + ylo, ilo = findmin(x; dims=dims) + yhi, ihi = findmax(x; dims=dims) project = ProjectTo(x) - function extrema_pullback((dyplus, dyminus)) + function extrema_pullback((dylo, dyhi)) + T = promote_type(eltype(dylo), eltype(dyhi)) + # @show T # often Any, when dyhi == NoTangent() x_thunk = @thunk begin - dx = fill!(similar(x, eltype(dy)), false) - view(dx, iplus) .= dyplus - view(dx, iminus) .+= dyminus + dx = fill!(similar(x, T), false) + view(dx, ilo) .+= dylo + view(dx, ihi) .+= dyhi project(dx) end x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, iplus) .+= dyplus - view(dx, iminus) .+= dyminus + view(dx, ilo) .+= dylo + view(dx, ihi) .+= dyhi dx end return (NoTangent(), x_ithunk) end - return (yplus, yminus), extrema_pullback + return (ylo, yhi), extrema_pullback end function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) From 4962e4a45531d48a13bd3350b74c997bf8fee692 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 23:24:59 -0400 Subject: [PATCH 03/30] symmetric maximum rule --- src/rulesets/Base/array.jl | 46 ++++++++++++++++++++++++++++++++----- test/rulesets/Base/array.jl | 2 +- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 5356152e6..08810340a 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -406,14 +406,48 @@ function rrule(::typeof(findmin), x::AbstractArray{<:Number}; dims=:) return (y, ind), findmin_pullback end -function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) - (y, _), back = rrule(findmax, x; dims=dims) - maximum_pullback(dy) = back((dy, nothing)) +# These functions pick the same subgradient as findmax: + +# function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) +# (y, _), back = rrule(findmax, x; dims=dims) +# maximum_pullback(dy) = back((dy, nothing)) +# return y, maximum_pullback +# end + +# function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) +# (y, _), back = rrule(findmin, x; dims=dims) +# minimum_pullback(dy) = back((dy, nothing)) +# return y, minimum_pullback +# end + +# These variants pick the symmetric convention: + +function rrule(::typeof(maximum), x::AbstractArray; dims=:) + y = maximum(x; dims=dims) + mask = (y .== x) # allocates & closes over a BitArray thefull size of x + count = sum(mask; dims=dims) # similar allocations to storing ind, if dims=1 etc. + project = ProjectTo(x) + function maximum_pullback(dy) + x_ithunk = InplaceableThunk( + dx -> dx .+= mask .* dy ./ count, + @thunk(project(mask .* dy ./ count),) + ) + return (NoTangent(), x_ithunk) + end return y, maximum_pullback end -function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) - (y, _), back = rrule(findmin, x; dims=dims) - minimum_pullback(dy) = back((dy, nothing)) +function rrule(::typeof(minimum), x::AbstractArray; dims=:) + y = minimum(x; dims=dims) + mask = (y .== x) + count = sum(mask; dims=dims) + project = ProjectTo(x) + function minimum_pullback(dy) + x_ithunk = InplaceableThunk( + dx -> dx .+= mask .* dy ./ count, + @thunk(project(mask .* dy ./ count),) + ) + return (NoTangent(), x_ithunk) + end return y, minimum_pullback end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 4fccadb84..ae5caf5d0 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -206,7 +206,7 @@ end test_rrule(f, rand(3,4), fkwargs=(dims=1,)) test_rrule(f, rand(3,4,5), fkwargs=(dims=(1,3),)) test_rrule(f, rand(1)) # both extrema are the same index - @test_skip test_rrule(f, Float64[1,2,-1,-2,0,2,-2]) # attains max twice -- finite diff picks another subgradient + test_rrule(f, Float64[1,2,-1,-2,0,2,-2]) # attains max twice -- finite diff picks symmetric subgradient end end From a60d244fb1ab4a0eaeb23411a7c01f98a1db0dd5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 23:39:55 -0400 Subject: [PATCH 04/30] promote types by hand --- src/rulesets/Base/array.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 08810340a..55b5ddddb 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -351,12 +351,20 @@ function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) ylo, ilo = findmin(x; dims=dims) yhi, ihi = findmax(x; dims=dims) project = ProjectTo(x) + extrema_pullback(dys::Tuple{AbstractZero, AbstractZero}) = (NoTangent(), NoTangent()) function extrema_pullback((dylo, dyhi)) - T = promote_type(eltype(dylo), eltype(dyhi)) + # T = promote_type(eltype(dylo), eltype(dyhi)) # @show T # often Any, when dyhi == NoTangent() + T = if dylo isa AbstractZero + eltype(dyhi) + elseif dyhi isa AbstractZero + eltype(dylo) + else + promote_type(eltype(dylo), eltype(dyhi)) + end x_thunk = @thunk begin dx = fill!(similar(x, T), false) - view(dx, ilo) .+= dylo + view(dx, ilo) .= dylo view(dx, ihi) .+= dyhi project(dx) end @@ -420,7 +428,8 @@ end # return y, minimum_pullback # end -# These variants pick the symmetric convention: +# These variants pick the symmetric convention, +# they are a bit slower. function rrule(::typeof(maximum), x::AbstractArray; dims=:) y = maximum(x; dims=dims) From 457a3f22dc0cef1efd504106e1a3e4acb3dc7289 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 27 Jul 2021 23:59:40 -0400 Subject: [PATCH 05/30] argmax? --- src/rulesets/Base/array.jl | 11 +++++++++++ src/rulesets/Base/nondiff.jl | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 55b5ddddb..150702f89 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -460,3 +460,14 @@ function rrule(::typeof(minimum), x::AbstractArray; dims=:) end return y, minimum_pullback end + +# function rrule(::typeof(argmax), x::AbstractArray{<:Number}; dims=:) +# argmax_pullback(dy) = (NoTangent(), NoTangent()) +# return argmax(x), argmax_pullback +# end + +# function rrule(::typeof(argmin), x::AbstractArray{<:Number}; dims=:) +# argmin_pullback(dy) = (NoTangent(), NoTangent()) +# return argmin(x), argmin_pullback +# end + diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 0412ac451..8f0c9e7d3 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -96,6 +96,22 @@ @non_differentiable all(::Any, ::Any) @non_differentiable any(::Any) @non_differentiable any(::Any, ::Any) +@non_differentiable argmax(::Any) +@non_differentiable argmin(::Any) +#= + +julia> gradient(argmax, rand(5)) +ERROR: MethodError: Cannot `convert` an object of type Bool to an object of type ChainRulesCore.ZeroTangent +Closest candidates are: + convert(::Type{T}, ::T) where T at essentials.jl:218 +Stacktrace: + [1] fill!(dest::Vector{ChainRulesCore.ZeroTangent}, x::Bool) + @ Base ./array.jl:351 + [2] (::ChainRules.var"#1191#1194"{ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Int64})() + @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/array.jl:312 + [3] unthunk + +=# @non_differentiable ascii(::AbstractString) @non_differentiable axes(::Any) @non_differentiable axes(::Any, ::Any) From 98df2f1999aca87349c8d0ba49d0c3f8a52e5cd8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 28 Jul 2021 12:16:12 -0400 Subject: [PATCH 06/30] allow more zeros --- src/rulesets/Base/array.jl | 4 +++- src/rulesets/Base/nondiff.jl | 14 -------------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 150702f89..d09a6b726 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -351,7 +351,7 @@ function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) ylo, ilo = findmin(x; dims=dims) yhi, ihi = findmax(x; dims=dims) project = ProjectTo(x) - extrema_pullback(dys::Tuple{AbstractZero, AbstractZero}) = (NoTangent(), NoTangent()) + extrema_pullback(::Tuple{AbstractZero, AbstractZero}) = (NoTangent(), NoTangent()) function extrema_pullback((dylo, dyhi)) # T = promote_type(eltype(dylo), eltype(dyhi)) # @show T # often Any, when dyhi == NoTangent() @@ -381,6 +381,7 @@ end function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) y, ind = findmax(x; dims=dims) project = ProjectTo(x) + findmax_pullback(::Tuple{AbstractZero, Any}) = (NoTangent(), NoTangent()) function findmax_pullback((dy, _)) x_thunk = @thunk begin dx = fill!(similar(x, eltype(dy)), false) @@ -399,6 +400,7 @@ end function rrule(::typeof(findmin), x::AbstractArray{<:Number}; dims=:) y, ind = findmin(x; dims=dims) project = ProjectTo(x) + findmin_pullback(::Tuple{AbstractZero, Any}) = (NoTangent(), NoTangent()) function findmin_pullback((dy, _)) x_thunk = @thunk begin dx = fill!(similar(x, eltype(dy)), false) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 8f0c9e7d3..dad5533f7 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -98,20 +98,6 @@ @non_differentiable any(::Any, ::Any) @non_differentiable argmax(::Any) @non_differentiable argmin(::Any) -#= - -julia> gradient(argmax, rand(5)) -ERROR: MethodError: Cannot `convert` an object of type Bool to an object of type ChainRulesCore.ZeroTangent -Closest candidates are: - convert(::Type{T}, ::T) where T at essentials.jl:218 -Stacktrace: - [1] fill!(dest::Vector{ChainRulesCore.ZeroTangent}, x::Bool) - @ Base ./array.jl:351 - [2] (::ChainRules.var"#1191#1194"{ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Int64})() - @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/array.jl:312 - [3] unthunk - -=# @non_differentiable ascii(::AbstractString) @non_differentiable axes(::Any) @non_differentiable axes(::Any, ::Any) From 2e63a86495e38bfa8592b40c2f9ec6aac5c0aac5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 28 Jul 2021 12:50:07 -0400 Subject: [PATCH 07/30] upgrade tests --- test/rulesets/Base/array.jl | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index ae5caf5d0..501b87f7c 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -200,13 +200,23 @@ end end @testset "extrema" begin - @testset "$f" for f in [maximum, minimum] - test_rrule(f, rand(10)) - test_rrule(f, rand(3,4)) - test_rrule(f, rand(3,4), fkwargs=(dims=1,)) - test_rrule(f, rand(3,4,5), fkwargs=(dims=(1,3),)) - test_rrule(f, rand(1)) # both extrema are the same index - test_rrule(f, Float64[1,2,-1,-2,0,2,-2]) # attains max twice -- finite diff picks symmetric subgradient - end + test_rrule(extrema, rand(10), output_tangent = (rand(), rand()), check_inferred=false) + @test_skip test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), rand(1,4)), check_inferred=false) # wrong answer? + # Case of both extrema are the same index: + test_rrule(extrema, rand(1), output_tangent = (rand(), rand()), check_inferred=false) +end + +@testset "$f" for f in [findmax, findmin] + @test_skip test_rrule(f, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error? + @test_skip test_rrule(f, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), 999), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? end +@testset "$f" for f in [maximum, minimum] + test_rrule(f, rand(10)) + test_rrule(f, rand(3,4)) + test_rrule(f, rand(3,4), fkwargs=(dims=1,)) + test_rrule(f, rand(3,4,5), fkwargs=(dims=(1,3),)) + # Case which attains max twice -- finite diff picks symmetric subgradient + test_rrule(f, Float64[1,2,-1,-2,0,2,-2]) + @test_skip test_rrule(f, Float64[1,2,-1,-2,0,2,-2,2,-2]) # ... three times, fails! +end From 47d98222c6ebce20c82a728a5b4b5778e26ebe8b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 18:56:33 -0400 Subject: [PATCH 08/30] don't do symmetric convention --- src/rulesets/Base/array.jl | 81 ++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 46 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index d09a6b726..dd91d6346 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -418,58 +418,47 @@ end # These functions pick the same subgradient as findmax: -# function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) -# (y, _), back = rrule(findmax, x; dims=dims) -# maximum_pullback(dy) = back((dy, nothing)) -# return y, maximum_pullback -# end - -# function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) -# (y, _), back = rrule(findmin, x; dims=dims) -# minimum_pullback(dy) = back((dy, nothing)) -# return y, minimum_pullback -# end - -# These variants pick the symmetric convention, -# they are a bit slower. - -function rrule(::typeof(maximum), x::AbstractArray; dims=:) - y = maximum(x; dims=dims) - mask = (y .== x) # allocates & closes over a BitArray thefull size of x - count = sum(mask; dims=dims) # similar allocations to storing ind, if dims=1 etc. - project = ProjectTo(x) - function maximum_pullback(dy) - x_ithunk = InplaceableThunk( - dx -> dx .+= mask .* dy ./ count, - @thunk(project(mask .* dy ./ count),) - ) - return (NoTangent(), x_ithunk) - end +function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) + (y, _), back = rrule(findmax, x; dims=dims) + maximum_pullback(dy) = back((dy, nothing)) return y, maximum_pullback end -function rrule(::typeof(minimum), x::AbstractArray; dims=:) - y = minimum(x; dims=dims) - mask = (y .== x) - count = sum(mask; dims=dims) - project = ProjectTo(x) - function minimum_pullback(dy) - x_ithunk = InplaceableThunk( - dx -> dx .+= mask .* dy ./ count, - @thunk(project(mask .* dy ./ count),) - ) - return (NoTangent(), x_ithunk) - end +function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) + (y, _), back = rrule(findmin, x; dims=dims) + minimum_pullback(dy) = back((dy, nothing)) return y, minimum_pullback end -# function rrule(::typeof(argmax), x::AbstractArray{<:Number}; dims=:) -# argmax_pullback(dy) = (NoTangent(), NoTangent()) -# return argmax(x), argmax_pullback -# end +# These variants pick the symmetric convention, +# they are a bit slower. -# function rrule(::typeof(argmin), x::AbstractArray{<:Number}; dims=:) -# argmin_pullback(dy) = (NoTangent(), NoTangent()) -# return argmin(x), argmin_pullback +# function rrule(::typeof(maximum), x::AbstractArray; dims=:) +# y = maximum(x; dims=dims) +# mask = (y .== x) # allocates & closes over a BitArray thefull size of x +# count = sum(mask; dims=dims) # similar allocations to storing ind, if dims=1 etc. +# project = ProjectTo(x) +# function maximum_pullback(dy) +# x_ithunk = InplaceableThunk( +# dx -> dx .+= mask .* dy ./ count, +# @thunk(project(mask .* dy ./ count),) +# ) +# return (NoTangent(), x_ithunk) +# end +# return y, maximum_pullback # end +# function rrule(::typeof(minimum), x::AbstractArray; dims=:) +# y = minimum(x; dims=dims) +# mask = (y .== x) +# count = sum(mask; dims=dims) +# project = ProjectTo(x) +# function minimum_pullback(dy) +# x_ithunk = InplaceableThunk( +# dx -> dx .+= mask .* dy ./ count, +# @thunk(project(mask .* dy ./ count),) +# ) +# return (NoTangent(), x_ithunk) +# end +# return y, minimum_pullback +# end From 8eb4b1caec45ecee139e9b63001d5e023eea5314 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 18:56:38 -0400 Subject: [PATCH 09/30] tests --- test/rulesets/Base/array.jl | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 501b87f7c..df7b950dc 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -206,17 +206,26 @@ end test_rrule(extrema, rand(1), output_tangent = (rand(), rand()), check_inferred=false) end -@testset "$f" for f in [findmax, findmin] - @test_skip test_rrule(f, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error? - @test_skip test_rrule(f, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), 999), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? -end - -@testset "$f" for f in [maximum, minimum] - test_rrule(f, rand(10)) - test_rrule(f, rand(3,4)) - test_rrule(f, rand(3,4), fkwargs=(dims=1,)) - test_rrule(f, rand(3,4,5), fkwargs=(dims=(1,3),)) - # Case which attains max twice -- finite diff picks symmetric subgradient - test_rrule(f, Float64[1,2,-1,-2,0,2,-2]) - @test_skip test_rrule(f, Float64[1,2,-1,-2,0,2,-2,2,-2]) # ... three times, fails! +@testset "$findm" for findm in [findmax, findmin] + @test_skip test_rrule(findm, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error? + @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), 999), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? +end + @test rrule(findmax, [1,2,33])[1] == (33, 3) + @test rrule(findmin, [11,22,33])[1] == (11, 1) + + @test [0,0,1] == @inferred unthunk(rrule(findmax, [1,2,3])[2]((1.0, nothing))[2]) + @test [1,0,0] == @inferred unthunk(rrule(findmin, [1,2,3])[2]((1.0, nothing))[2]) + + @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5, nothing))[2]) + @test [5 0; 0 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4])[2]((5, nothing))[2]) + + +@testset "$imum" for imum in [maximum, minimum] + test_rrule(imum, rand(10)) + test_rrule(imum, rand(3,4)) + test_rrule(imum, rand(3,4), fkwargs=(dims=1,)) + test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),)) + # Case which attains max twice: + @test_skip test_rrule(imum, Float64[1,2,-1,-2,0,2,-2]) # finite diff picks symmetric subgradient? + @test_skip test_rrule(imum, Float64[1,2,-1,-2,0,2,-2,2,-2]) # finite diff does something else. end From 2df968162db84fc977074f198111250c6787b24f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 21:01:38 -0400 Subject: [PATCH 10/30] fix 1.0 --- src/rulesets/Base/array.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index dd91d6346..0f8b98170 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -365,12 +365,13 @@ function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) x_thunk = @thunk begin dx = fill!(similar(x, T), false) view(dx, ilo) .= dylo - view(dx, ihi) .+= dyhi + # view(dx, ihi) .+= dyhi # illegal on 1.0! + view(dx, ihi) .= view(dx, ihi) .+ dyhi project(dx) end x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ilo) .+= dylo - view(dx, ihi) .+= dyhi + view(dx, ilo) .= view(dx, ilo) .+ dylo + view(dx, ihi) .= view(dx, ihi) .+ dyhi dx end return (NoTangent(), x_ithunk) @@ -389,7 +390,7 @@ function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) project(dx) end x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ind) .+= dy + view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0 dx end return (NoTangent(), x_ithunk) @@ -408,7 +409,7 @@ function rrule(::typeof(findmin), x::AbstractArray{<:Number}; dims=:) project(dx) end x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ind) .+= dy + view(dx, ind) .= view(dx, ind) .+ dy dx end return (NoTangent(), x_ithunk) From 9fdfec3dc066e378a3d57c2723524c2e1a00725b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 21:33:29 -0400 Subject: [PATCH 11/30] rm symmetric versions --- src/rulesets/Base/array.jl | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 0f8b98170..6f4928c07 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -431,35 +431,3 @@ function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) return y, minimum_pullback end -# These variants pick the symmetric convention, -# they are a bit slower. - -# function rrule(::typeof(maximum), x::AbstractArray; dims=:) -# y = maximum(x; dims=dims) -# mask = (y .== x) # allocates & closes over a BitArray thefull size of x -# count = sum(mask; dims=dims) # similar allocations to storing ind, if dims=1 etc. -# project = ProjectTo(x) -# function maximum_pullback(dy) -# x_ithunk = InplaceableThunk( -# dx -> dx .+= mask .* dy ./ count, -# @thunk(project(mask .* dy ./ count),) -# ) -# return (NoTangent(), x_ithunk) -# end -# return y, maximum_pullback -# end - -# function rrule(::typeof(minimum), x::AbstractArray; dims=:) -# y = minimum(x; dims=dims) -# mask = (y .== x) -# count = sum(mask; dims=dims) -# project = ProjectTo(x) -# function minimum_pullback(dy) -# x_ithunk = InplaceableThunk( -# dx -> dx .+= mask .* dy ./ count, -# @thunk(project(mask .* dy ./ count),) -# ) -# return (NoTangent(), x_ithunk) -# end -# return y, minimum_pullback -# end From bf0bf2e9d7b0643ad0f9fd8f9dfce2a93a07f343 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 21:34:10 -0400 Subject: [PATCH 12/30] move extrema to last --- src/rulesets/Base/array.jl | 65 ++++++++++++++++++------------------- test/rulesets/Base/array.jl | 14 ++++---- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 6f4928c07..22eab01ec 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -344,41 +344,9 @@ function rrule(::typeof(fill), x::Any, dims...) end ##### -##### `extrema`, `findmax`, `maximum`, etc. +##### `findmax`, `maximum`, `extrema`, etc. ##### -function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) - ylo, ilo = findmin(x; dims=dims) - yhi, ihi = findmax(x; dims=dims) - project = ProjectTo(x) - extrema_pullback(::Tuple{AbstractZero, AbstractZero}) = (NoTangent(), NoTangent()) - function extrema_pullback((dylo, dyhi)) - # T = promote_type(eltype(dylo), eltype(dyhi)) - # @show T # often Any, when dyhi == NoTangent() - T = if dylo isa AbstractZero - eltype(dyhi) - elseif dyhi isa AbstractZero - eltype(dylo) - else - promote_type(eltype(dylo), eltype(dyhi)) - end - x_thunk = @thunk begin - dx = fill!(similar(x, T), false) - view(dx, ilo) .= dylo - # view(dx, ihi) .+= dyhi # illegal on 1.0! - view(dx, ihi) .= view(dx, ihi) .+ dyhi - project(dx) - end - x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ilo) .= view(dx, ilo) .+ dylo - view(dx, ihi) .= view(dx, ihi) .+ dyhi - dx - end - return (NoTangent(), x_ithunk) - end - return (ylo, yhi), extrema_pullback -end - function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) y, ind = findmax(x; dims=dims) project = ProjectTo(x) @@ -431,3 +399,34 @@ function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) return y, minimum_pullback end +function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) + ylo, ilo = findmin(x; dims=dims) + yhi, ihi = findmax(x; dims=dims) + project = ProjectTo(x) + extrema_pullback(::Tuple{AbstractZero, AbstractZero}) = (NoTangent(), NoTangent()) + function extrema_pullback((dylo, dyhi)) + # T = promote_type(eltype(dylo), eltype(dyhi)) + # @show T # often Any, when dyhi == NoTangent() + T = if dylo isa AbstractZero + eltype(dyhi) + elseif dyhi isa AbstractZero + eltype(dylo) + else + promote_type(eltype(dylo), eltype(dyhi)) + end + x_thunk = @thunk begin + dx = fill!(similar(x, T), false) + view(dx, ilo) .= dylo + # view(dx, ihi) .+= dyhi # illegal on 1.0! + view(dx, ihi) .= view(dx, ihi) .+ dyhi + project(dx) + end + x_ithunk = InplaceableThunk(x_thunk) do dx + view(dx, ilo) .= view(dx, ilo) .+ dylo + view(dx, ihi) .= view(dx, ihi) .+ dyhi + dx + end + return (NoTangent(), x_ithunk) + end + return (ylo, yhi), extrema_pullback +end \ No newline at end of file diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index df7b950dc..c4000bdea 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -199,13 +199,6 @@ end test_rrule(fill, 3.3, (3, 3, 3)) end -@testset "extrema" begin - test_rrule(extrema, rand(10), output_tangent = (rand(), rand()), check_inferred=false) - @test_skip test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), rand(1,4)), check_inferred=false) # wrong answer? - # Case of both extrema are the same index: - test_rrule(extrema, rand(1), output_tangent = (rand(), rand()), check_inferred=false) -end - @testset "$findm" for findm in [findmax, findmin] @test_skip test_rrule(findm, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error? @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), 999), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? @@ -229,3 +222,10 @@ end @test_skip test_rrule(imum, Float64[1,2,-1,-2,0,2,-2]) # finite diff picks symmetric subgradient? @test_skip test_rrule(imum, Float64[1,2,-1,-2,0,2,-2,2,-2]) # finite diff does something else. end + +@testset "extrema" begin + test_rrule(extrema, rand(10), output_tangent = (rand(), rand()), check_inferred=false) + @test_skip test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), rand(1,4)), check_inferred=false) # wrong answer? + # Case of both extrema are the same index: + test_rrule(extrema, rand(1), output_tangent = (rand(), rand()), check_inferred=false) +end From 5c57f44401c0a6b368aa010b6b6b53c775381806 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 22:07:06 -0400 Subject: [PATCH 13/30] tidy --- src/rulesets/Base/array.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 22eab01ec..40752f389 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -350,7 +350,6 @@ end function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) y, ind = findmax(x; dims=dims) project = ProjectTo(x) - findmax_pullback(::Tuple{AbstractZero, Any}) = (NoTangent(), NoTangent()) function findmax_pullback((dy, _)) x_thunk = @thunk begin dx = fill!(similar(x, eltype(dy)), false) @@ -363,13 +362,15 @@ function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) end return (NoTangent(), x_ithunk) end + function findmax_pullback(::Tuple{AbstractZero, Any}) + return (NoTangent(), NoTangent()) + end return (y, ind), findmax_pullback end function rrule(::typeof(findmin), x::AbstractArray{<:Number}; dims=:) y, ind = findmin(x; dims=dims) project = ProjectTo(x) - findmin_pullback(::Tuple{AbstractZero, Any}) = (NoTangent(), NoTangent()) function findmin_pullback((dy, _)) x_thunk = @thunk begin dx = fill!(similar(x, eltype(dy)), false) @@ -382,6 +383,9 @@ function rrule(::typeof(findmin), x::AbstractArray{<:Number}; dims=:) end return (NoTangent(), x_ithunk) end + function findmin_pullback(::Tuple{AbstractZero, Any}) + return (NoTangent(), NoTangent()) + end return (y, ind), findmin_pullback end @@ -403,10 +407,7 @@ function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) ylo, ilo = findmin(x; dims=dims) yhi, ihi = findmax(x; dims=dims) project = ProjectTo(x) - extrema_pullback(::Tuple{AbstractZero, AbstractZero}) = (NoTangent(), NoTangent()) function extrema_pullback((dylo, dyhi)) - # T = promote_type(eltype(dylo), eltype(dyhi)) - # @show T # often Any, when dyhi == NoTangent() T = if dylo isa AbstractZero eltype(dyhi) elseif dyhi isa AbstractZero @@ -428,5 +429,8 @@ function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) end return (NoTangent(), x_ithunk) end + function extrema_pullback(::Tuple{AbstractZero, AbstractZero}) + return (NoTangent(), NoTangent()) + end return (ylo, yhi), extrema_pullback end \ No newline at end of file From 1ba053a4fc055f1965f00dadb92930b1e28d8bec Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 11:20:46 -0400 Subject: [PATCH 14/30] fixup extrema --- src/rulesets/Base/array.jl | 64 +++++++++++++++++++++++++++---------- test/rulesets/Base/array.jl | 15 +++++---- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 40752f389..d3907a5a4 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -404,33 +404,63 @@ function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) end function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) - ylo, ilo = findmin(x; dims=dims) - yhi, ihi = findmax(x; dims=dims) + if dims isa Colon + return _extrema_colon(x) + else + return _extrema_dims(x, dims) + end +end + +function _extrema_colon(x) + ylo, ilo = findmin(x) + yhi, ihi = findmax(x) project = ProjectTo(x) function extrema_pullback((dylo, dyhi)) - T = if dylo isa AbstractZero - eltype(dyhi) - elseif dyhi isa AbstractZero - eltype(dylo) - else - promote_type(eltype(dylo), eltype(dyhi)) - end - x_thunk = @thunk begin + # One argument may be AbstractZero here, promote_type allows *, hence gives Any: + T = Base.promote_op(+, typeof(dylo), typeof(dyhi)) + x_nothunk = let + # x_thunk = @thunk begin # this doesn't infer dx = fill!(similar(x, T), false) view(dx, ilo) .= dylo - # view(dx, ihi) .+= dyhi # illegal on 1.0! view(dx, ihi) .= view(dx, ihi) .+ dyhi project(dx) end - x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ilo) .= view(dx, ilo) .+ dylo - view(dx, ihi) .= view(dx, ihi) .+ dyhi - dx - end - return (NoTangent(), x_ithunk) + # x_ithunk = InplaceableThunk(x_thunk) do dx + # view(dx, ilo) .= view(dx, ilo) .+ dylo + # view(dx, ihi) .= view(dx, ihi) .+ dyhi + # dx + # end + return (NoTangent(), x_nothunk) end function extrema_pullback(::Tuple{AbstractZero, AbstractZero}) return (NoTangent(), NoTangent()) end return (ylo, yhi), extrema_pullback +end + +function _extrema_dims(x, dims) + ylo, ilo = findmin(x; dims=dims) + yhi, ihi = findmax(x; dims=dims) + y = similar(ylo, Tuple{eltype(ylo), eltype(yhi)}) + map!(tuple, y, ylo, yhi) # this is a GPU-friendly version of collect(zip(ylo, yhi)) + project = ProjectTo(x) + function extrema_pullback_dims(dy_raw) + dy = unthunk(dy_raw) + @assert dy isa AbstractArray{<:Tuple{Any,Any}} + T = Base.promote_op(+, eltype(dy).parameters...) # can we actually get Array{Tuple{Float64,ZeroTangent}} here? + x_nothunk = let + # x_thunk = @thunk begin # this doesn't infer + dx = fill!(similar(x, T), false) + view(dx, ilo) .= first.(dy) + view(dx, ihi) .= view(dx, ihi) .+ last.(dy) + project(dx) + end + # x_ithunk = InplaceableThunk(x_thunk) do dx + # view(dx, ilo) .= first.(dy) + # view(dx, ihi) .= view(dx, ihi) .+ last.(dy) + # dx + # end + return (NoTangent(), x_nothunk) + end + return y, extrema_pullback_dims end \ No newline at end of file diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index c4000bdea..20cd74919 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -218,14 +218,15 @@ end test_rrule(imum, rand(3,4)) test_rrule(imum, rand(3,4), fkwargs=(dims=1,)) test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),)) - # Case which attains max twice: - @test_skip test_rrule(imum, Float64[1,2,-1,-2,0,2,-2]) # finite diff picks symmetric subgradient? - @test_skip test_rrule(imum, Float64[1,2,-1,-2,0,2,-2,2,-2]) # finite diff does something else. + # Case which attains max twice -- can't use FiniteDifferences for this + res = imum == maximum ? [0,1,0,0,0,0] : [1,0,0,0,0,0] + @test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2]) end @testset "extrema" begin - test_rrule(extrema, rand(10), output_tangent = (rand(), rand()), check_inferred=false) - @test_skip test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), rand(1,4)), check_inferred=false) # wrong answer? - # Case of both extrema are the same index: - test_rrule(extrema, rand(1), output_tangent = (rand(), rand()), check_inferred=false) + test_rrule(extrema, rand(10), output_tangent = (rand(), rand())) + test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = collect(zip(rand(1,4), rand(1,4)))) # wrong answer? + # Case of both extrema are the same index, to check accumulation: + test_rrule(extrema, rand(1), output_tangent = (rand(), rand())) + test_rrule(extrema, rand(1,1), fkwargs=(dims=2,), output_tangent = hcat((rand(), rand()))) end From 15ccc56962c0a3dba0b70b622e60e85970089bd4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 11:45:40 -0400 Subject: [PATCH 15/30] tests --- src/rulesets/Base/array.jl | 2 +- test/rulesets/Base/array.jl | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index d3907a5a4..a300352da 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -463,4 +463,4 @@ function _extrema_dims(x, dims) return (NoTangent(), x_nothunk) end return y, extrema_pullback_dims -end \ No newline at end of file +end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 20cd74919..fa6af19cb 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -200,8 +200,10 @@ end end @testset "$findm" for findm in [findmax, findmin] - @test_skip test_rrule(findm, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error? - @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), 999), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? + @test_skip test_rrule(findm, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error from FiniteDifferences + test_rrule(findm, rand(10), output_tangent = (rand(), false)) + @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? + @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") end @test rrule(findmax, [1,2,33])[1] == (33, 3) @test rrule(findmin, [11,22,33])[1] == (11, 1) @@ -225,8 +227,14 @@ end @testset "extrema" begin test_rrule(extrema, rand(10), output_tangent = (rand(), rand())) - test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = collect(zip(rand(1,4), rand(1,4)))) # wrong answer? - # Case of both extrema are the same index, to check accumulation: + test_rrule(extrema, rand(3,4), fkwargs=(dims=1,), output_tangent = collect(zip(rand(1,4), rand(1,4)))) + # Case where both extrema are the same index, to check accumulation: test_rrule(extrema, rand(1), output_tangent = (rand(), rand())) test_rrule(extrema, rand(1,1), fkwargs=(dims=2,), output_tangent = hcat((rand(), rand()))) + test_rrule(extrema, rand(3,1), fkwargs=(dims=2,), output_tangent = collect(zip(rand(3,1), rand(3,1)))) + # Double-check the forward pass + A = randn(3,4,5) + @test extrema(A, dims=(1,3)) == rrule(extrema, A, dims=(1,3))[1] + B = hcat(A[:,:,1], A[:,:,1]) + @test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1] end From 7ef9a32637e47054a1fe0426ccabb909afdae7a7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 11:48:36 -0400 Subject: [PATCH 16/30] tests --- test/rulesets/Base/array.jl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index fa6af19cb..ac0c00a5a 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -205,15 +205,8 @@ end @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") end - @test rrule(findmax, [1,2,33])[1] == (33, 3) - @test rrule(findmin, [11,22,33])[1] == (11, 1) - - @test [0,0,1] == @inferred unthunk(rrule(findmax, [1,2,3])[2]((1.0, nothing))[2]) - @test [1,0,0] == @inferred unthunk(rrule(findmin, [1,2,3])[2]((1.0, nothing))[2]) - - @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5, nothing))[2]) - @test [5 0; 0 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4])[2]((5, nothing))[2]) - + @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2]) + @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) @testset "$imum" for imum in [maximum, minimum] test_rrule(imum, rand(10)) From 3af83c00395996a372f3d6d4372abd1390879d04 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 12:21:26 -0400 Subject: [PATCH 17/30] use eval loop, tidy, tests --- src/rulesets/Base/array.jl | 62 +++++++++++++++---------------------- test/rulesets/Base/array.jl | 17 ++++++---- 2 files changed, 36 insertions(+), 43 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index a300352da..3226e7a7a 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -344,49 +344,33 @@ function rrule(::typeof(fill), x::Any, dims...) end ##### -##### `findmax`, `maximum`, `extrema`, etc. +##### `findmax`, `maximum`, etc. ##### -function rrule(::typeof(findmax), x::AbstractArray{<:Number}; dims=:) - y, ind = findmax(x; dims=dims) - project = ProjectTo(x) - function findmax_pullback((dy, _)) - x_thunk = @thunk begin - dx = fill!(similar(x, eltype(dy)), false) - view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array cases - project(dx) +for findm in (:findmin, :findmax) + findm_pullback = Symbol(findm, :_pullback) + + @eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:) + y, ind = $findm(x; dims=dims) + project = ProjectTo(x) + function $findm_pullback((dy, _)) + x_thunk = @thunk begin + dx = fill!(similar(x, eltype(dy)), false) + view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array cases + project(dx) + end + x_ithunk = InplaceableThunk(x_thunk) do dx + view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0 + dx + end + return (NoTangent(), x_ithunk) end - x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0 - dx + function findmax_pullback(::Tuple{AbstractZero, Any}) + return (NoTangent(), NoTangent()) end - return (NoTangent(), x_ithunk) + return (y, ind), $findm_pullback end - function findmax_pullback(::Tuple{AbstractZero, Any}) - return (NoTangent(), NoTangent()) - end - return (y, ind), findmax_pullback -end -function rrule(::typeof(findmin), x::AbstractArray{<:Number}; dims=:) - y, ind = findmin(x; dims=dims) - project = ProjectTo(x) - function findmin_pullback((dy, _)) - x_thunk = @thunk begin - dx = fill!(similar(x, eltype(dy)), false) - view(dx, ind) .= dy - project(dx) - end - x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ind) .= view(dx, ind) .+ dy - dx - end - return (NoTangent(), x_ithunk) - end - function findmin_pullback(::Tuple{AbstractZero, Any}) - return (NoTangent(), NoTangent()) - end - return (y, ind), findmin_pullback end # These functions pick the same subgradient as findmax: @@ -403,6 +387,10 @@ function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) return y, minimum_pullback end +##### +##### `extrema` +##### + function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) if dims isa Colon return _extrema_colon(x) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index ac0c00a5a..e24d440eb 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -199,14 +199,19 @@ end test_rrule(fill, 3.3, (3, 3, 3)) end -@testset "$findm" for findm in [findmax, findmin] - @test_skip test_rrule(findm, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error from FiniteDifferences - test_rrule(findm, rand(10), output_tangent = (rand(), false)) - @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? - @test_skip test_rrule(findm, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") -end +@testset "findmin & findmax" begin + test_rrule(findmin, rand(10), output_tangent = (rand(), false)) + test_rrule(findmax, rand(10), output_tangent = (rand(), false)) + @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) + @test_skip test_rrule(findmin, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error from FiniteDifferences + @test_skip test_rrule(findmax, rand(5,3), output_tangent = (rand(), false), check_inferred=false) # error from FiniteDifferences + + # With dims: @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2]) @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) + @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? + @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") +end @testset "$imum" for imum in [maximum, minimum] test_rrule(imum, rand(10)) From 8fca98aa8e352301583a841e733ad6a9b0fe087c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 12:30:20 -0400 Subject: [PATCH 18/30] forward rules for maximum --- src/rulesets/Base/array.jl | 21 ++++++++++++++++++--- test/rulesets/Base/array.jl | 4 ++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 3226e7a7a..ba4e12b41 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -350,13 +350,18 @@ end for findm in (:findmin, :findmax) findm_pullback = Symbol(findm, :_pullback) + # @eval function frule((_, xdot), ::typeof($findm), x; dims=:) + # y, ind = $findm(x; dims=dims) + # return (y, ind), (xdot[ind], NoTangent()) # needs to be some Tangent? + # end + @eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:) y, ind = $findm(x; dims=dims) project = ProjectTo(x) function $findm_pullback((dy, _)) x_thunk = @thunk begin dx = fill!(similar(x, eltype(dy)), false) - view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array cases + view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray project(dx) end x_ithunk = InplaceableThunk(x_thunk) do dx @@ -365,7 +370,7 @@ for findm in (:findmin, :findmax) end return (NoTangent(), x_ithunk) end - function findmax_pullback(::Tuple{AbstractZero, Any}) + function $findm_pullback(::Tuple{AbstractZero, Any}) return (NoTangent(), NoTangent()) end return (y, ind), $findm_pullback @@ -373,7 +378,12 @@ for findm in (:findmin, :findmax) end -# These functions pick the same subgradient as findmax: +# These rules for maximum pick the same subgradient as findmax: + +function frule((_, xdot), ::typeof(maximum), x; dims=:) + y, ind = findmax(x; dims=dims) + return y, xdot[ind] +end function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) (y, _), back = rrule(findmax, x; dims=dims) @@ -381,6 +391,11 @@ function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) return y, maximum_pullback end +function frule((_, xdot), ::typeof(minimum), x; dims=:) + y, ind = findmin(x; dims=dims) + return y, xdot[ind] +end + function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) (y, _), back = rrule(findmin, x; dims=dims) minimum_pullback(dy) = back((dy, nothing)) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index e24d440eb..8a3e73374 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -214,6 +214,10 @@ end end @testset "$imum" for imum in [maximum, minimum] + # Forward + test_frule(imum, rand(10)) + test_frule(imum, rand(3,4), fkwargs=(dims=1,)) + # Reverse test_rrule(imum, rand(10)) test_rrule(imum, rand(3,4)) test_rrule(imum, rand(3,4), fkwargs=(dims=1,)) From ee7d97f4c04912960735bca5703d7aed8249341e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 12:37:54 -0400 Subject: [PATCH 19/30] frules for findmax --- src/rulesets/Base/array.jl | 9 +++++---- test/rulesets/Base/array.jl | 10 +++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index ba4e12b41..4b942a695 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -350,10 +350,11 @@ end for findm in (:findmin, :findmax) findm_pullback = Symbol(findm, :_pullback) - # @eval function frule((_, xdot), ::typeof($findm), x; dims=:) - # y, ind = $findm(x; dims=dims) - # return (y, ind), (xdot[ind], NoTangent()) # needs to be some Tangent? - # end + @eval function frule((_, xdot), ::typeof($findm), x; dims=:) + y, ind = $findm(x; dims=dims) + ydot = (xdot[ind], NoTangent()) + return (y, ind), Tangent{typeof((y, ind)),typeof(ydot)}(ydot) + end @eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:) y, ind = $findm(x; dims=dims) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 8a3e73374..4f24e5fb2 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -200,13 +200,20 @@ end end @testset "findmin & findmax" begin + # Forward + test_frule(findmin, rand(10)) + test_frule(findmax, rand(10)) + @test_skip test_frule(findmin, rand(3,4)) # StackOverflowError, TypeError: in new, expected Tuple{Int64, Int64}, got a value of type Tuple{Float64, Float64} + @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,)) + + # Reverse test_rrule(findmin, rand(10), output_tangent = (rand(), false)) test_rrule(findmax, rand(10), output_tangent = (rand(), false)) @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) @test_skip test_rrule(findmin, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error from FiniteDifferences @test_skip test_rrule(findmax, rand(5,3), output_tangent = (rand(), false), check_inferred=false) # error from FiniteDifferences - # With dims: + # Reverse with dims: @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2]) @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? @@ -216,6 +223,7 @@ end @testset "$imum" for imum in [maximum, minimum] # Forward test_frule(imum, rand(10)) + test_frule(imum, rand(3,4)) test_frule(imum, rand(3,4), fkwargs=(dims=1,)) # Reverse test_rrule(imum, rand(10)) From 67ea14db643f17fa8a6b0d27e95d006dbfb2fb45 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 31 Jul 2021 15:29:25 -0400 Subject: [PATCH 20/30] tidy --- src/rulesets/Base/array.jl | 3 +-- test/rulesets/Base/array.jl | 8 +++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 4b942a695..9f42b5ac5 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -352,8 +352,7 @@ for findm in (:findmin, :findmax) @eval function frule((_, xdot), ::typeof($findm), x; dims=:) y, ind = $findm(x; dims=dims) - ydot = (xdot[ind], NoTangent()) - return (y, ind), Tangent{typeof((y, ind)),typeof(ydot)}(ydot) + return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent()) end @eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 4f24e5fb2..45f444454 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -203,15 +203,17 @@ end # Forward test_frule(findmin, rand(10)) test_frule(findmax, rand(10)) - @test_skip test_frule(findmin, rand(3,4)) # StackOverflowError, TypeError: in new, expected Tuple{Int64, Int64}, got a value of type Tuple{Float64, Float64} + @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4))) isa Tuple{Tuple{Float64, CartesianIndex}, Tangent} + @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4), dims=1)) isa Tuple{Tuple{Matrix, Matrix}, Tangent} + @test_skip test_frule(findmin, rand(3,4)) # StackOverflowError, CartesianIndex{2}(index::Tuple{Float64, Float64}) (repeats 79984 times) & TypeError: in new, expected Tuple{Int64, Int64}, got a value of type Tuple{Float64, Float64} @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,)) # Reverse test_rrule(findmin, rand(10), output_tangent = (rand(), false)) test_rrule(findmax, rand(10), output_tangent = (rand(), false)) @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) - @test_skip test_rrule(findmin, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # error from FiniteDifferences - @test_skip test_rrule(findmax, rand(5,3), output_tangent = (rand(), false), check_inferred=false) # error from FiniteDifferences + @test_skip test_rrule(findmin, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # DimensionMismatch from FiniteDifferences + @test_skip test_rrule(findmax, rand(5,3), output_tangent = (rand(), false), check_inferred=false) # DimensionMismatch from FiniteDifferences # Reverse with dims: @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2]) From b17be68d8750b4b7a7de2f7495958143ec53077c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 2 Aug 2021 11:02:42 -0400 Subject: [PATCH 21/30] widen similar to ensure writeability --- src/rulesets/Base/array.jl | 8 ++++---- test/rulesets/Base/array.jl | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 9f42b5ac5..edd83484a 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -358,9 +358,9 @@ for findm in (:findmin, :findmax) @eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:) y, ind = $findm(x; dims=dims) project = ProjectTo(x) - function $findm_pullback((dy, _)) + function $findm_pullback((dy, _)) # this accept e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) x_thunk = @thunk begin - dx = fill!(similar(x, eltype(dy)), false) + dx = fill!(similar(x, eltype(dy), axes(x)), false) view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray project(dx) end @@ -423,7 +423,7 @@ function _extrema_colon(x) T = Base.promote_op(+, typeof(dylo), typeof(dyhi)) x_nothunk = let # x_thunk = @thunk begin # this doesn't infer - dx = fill!(similar(x, T), false) + dx = fill!(similar(x, T, axes(x)), false) view(dx, ilo) .= dylo view(dx, ihi) .= view(dx, ihi) .+ dyhi project(dx) @@ -453,7 +453,7 @@ function _extrema_dims(x, dims) T = Base.promote_op(+, eltype(dy).parameters...) # can we actually get Array{Tuple{Float64,ZeroTangent}} here? x_nothunk = let # x_thunk = @thunk begin # this doesn't infer - dx = fill!(similar(x, T), false) + dx = fill!(similar(x, T, axes(x)), false) view(dx, ilo) .= first.(dy) view(dx, ihi) .= view(dx, ihi) .+ last.(dy) project(dx) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 45f444454..eae441525 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -227,14 +227,21 @@ end test_frule(imum, rand(10)) test_frule(imum, rand(3,4)) test_frule(imum, rand(3,4), fkwargs=(dims=1,)) + # Reverse test_rrule(imum, rand(10)) test_rrule(imum, rand(3,4)) test_rrule(imum, rand(3,4), fkwargs=(dims=1,)) test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),)) + # Case which attains max twice -- can't use FiniteDifferences for this res = imum == maximum ? [0,1,0,0,0,0] : [1,0,0,0,0,0] @test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2]) + + # Structured matrix -- NB the minimum is a structral zero here + @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal + @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular + @test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool) end @testset "extrema" begin From 8d96687cf5e32d7fe7aed6ca58f63f352c85d9ef Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 2 Aug 2021 11:17:33 -0400 Subject: [PATCH 22/30] comments --- src/rulesets/Base/array.jl | 9 +++++++-- test/rulesets/Base/array.jl | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index edd83484a..d463f158f 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -358,8 +358,11 @@ for findm in (:findmin, :findmax) @eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:) y, ind = $findm(x; dims=dims) project = ProjectTo(x) + # This pullback is a lot like the one for getindex. Ideally they would probably be combined? function $findm_pullback((dy, _)) # this accept e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) x_thunk = @thunk begin + # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't + # allow `eltype(dy)`, nor does it work for many structured matrices. dx = fill!(similar(x, eltype(dy), axes(x)), false) view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray project(dx) @@ -419,7 +422,8 @@ function _extrema_colon(x) yhi, ihi = findmax(x) project = ProjectTo(x) function extrema_pullback((dylo, dyhi)) - # One argument may be AbstractZero here, promote_type allows *, hence gives Any: + # One argument may be AbstractZero here. Use promote_op because + # promote_type allows for * as well as +, hence gives Any. T = Base.promote_op(+, typeof(dylo), typeof(dyhi)) x_nothunk = let # x_thunk = @thunk begin # this doesn't infer @@ -450,7 +454,8 @@ function _extrema_dims(x, dims) function extrema_pullback_dims(dy_raw) dy = unthunk(dy_raw) @assert dy isa AbstractArray{<:Tuple{Any,Any}} - T = Base.promote_op(+, eltype(dy).parameters...) # can we actually get Array{Tuple{Float64,ZeroTangent}} here? + # Can we actually get Array{Tuple{Float64,ZeroTangent}} here? Not sure. + T = Base.promote_op(+, eltype(dy).parameters...) x_nothunk = let # x_thunk = @thunk begin # this doesn't infer dx = fill!(similar(x, T, axes(x)), false) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index eae441525..6445eac3f 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -206,6 +206,7 @@ end @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4))) isa Tuple{Tuple{Float64, CartesianIndex}, Tangent} @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4), dims=1)) isa Tuple{Tuple{Matrix, Matrix}, Tangent} @test_skip test_frule(findmin, rand(3,4)) # StackOverflowError, CartesianIndex{2}(index::Tuple{Float64, Float64}) (repeats 79984 times) & TypeError: in new, expected Tuple{Int64, Int64}, got a value of type Tuple{Float64, Float64} + @test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent())) @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,)) # Reverse From 5bd06ada10cdca73f187150d65b973b088d23d31 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 2 Aug 2021 11:49:16 -0400 Subject: [PATCH 23/30] dispatch -> branch --- src/rulesets/Base/array.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index d463f158f..a94d40ca5 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -360,6 +360,7 @@ for findm in (:findmin, :findmax) project = ProjectTo(x) # This pullback is a lot like the one for getindex. Ideally they would probably be combined? function $findm_pullback((dy, _)) # this accept e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) + dy isa AbstractZero && return (NoTangent(), NoTangent()) x_thunk = @thunk begin # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't # allow `eltype(dy)`, nor does it work for many structured matrices. @@ -373,9 +374,6 @@ for findm in (:findmin, :findmax) end return (NoTangent(), x_ithunk) end - function $findm_pullback(::Tuple{AbstractZero, Any}) - return (NoTangent(), NoTangent()) - end return (y, ind), $findm_pullback end @@ -421,7 +419,10 @@ function _extrema_colon(x) ylo, ilo = findmin(x) yhi, ihi = findmax(x) project = ProjectTo(x) - function extrema_pullback((dylo, dyhi)) + function extrema_pullback((dylo, dyhi)) # accepts Tangent + if (dylo, dyhi) isa Tuple{AbstractZero, AbstractZero} + return (NoTangent(), NoTangent()) + end # One argument may be AbstractZero here. Use promote_op because # promote_type allows for * as well as +, hence gives Any. T = Base.promote_op(+, typeof(dylo), typeof(dyhi)) @@ -439,9 +440,6 @@ function _extrema_colon(x) # end return (NoTangent(), x_nothunk) end - function extrema_pullback(::Tuple{AbstractZero, AbstractZero}) - return (NoTangent(), NoTangent()) - end return (ylo, yhi), extrema_pullback end From 9813783e4f237480d4e9c639ff827071b1b85272 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 2 Aug 2021 23:20:37 -0400 Subject: [PATCH 24/30] allow for second derivatives --- src/rulesets/Base/array.jl | 29 ++++++++++++++++++++--------- test/rulesets/Base/array.jl | 4 ++++ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index a94d40ca5..ae7e53eb2 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -361,13 +361,7 @@ for findm in (:findmin, :findmax) # This pullback is a lot like the one for getindex. Ideally they would probably be combined? function $findm_pullback((dy, _)) # this accept e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - x_thunk = @thunk begin - # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't - # allow `eltype(dy)`, nor does it work for many structured matrices. - dx = fill!(similar(x, eltype(dy), axes(x)), false) - view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray - project(dx) - end + x_thunk = @thunk project(_writezero(x, dy, ind, dims)) x_ithunk = InplaceableThunk(x_thunk) do dx view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0 dx @@ -379,7 +373,24 @@ for findm in (:findmin, :findmax) end -# These rules for maximum pick the same subgradient as findmax: +function _writezero(x, dy, ind, dims) + # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't + # allow `eltype(dy)`, nor does it work for many structured matrices. + dx = fill!(similar(x, eltype(dy), axes(x)), false) + view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray + dx +end + +function rrule(::typeof(_writezero), x, dy, ind, dims) + z = _writezero(x, dy, ind, dims) + _writezero_pullback(dz) = (NoTangent(), NoTangent(), sum(view(unthunk(dz), ind); dims=dims), NoTangent(), NoTangent()) + return z, _writezero_pullback +end + +Base.view(z::AbstractZero, ind...) = z # TODO move to ChainRulesCore +Base.sum(z::AbstractZero; dims=:) = z # TODO move to ChainRulesCore + +# These rules for `maximum` pick the same subgradient as findmax: function frule((_, xdot), ::typeof(maximum), x; dims=:) y, ind = findmax(x; dims=dims) @@ -456,7 +467,7 @@ function _extrema_dims(x, dims) T = Base.promote_op(+, eltype(dy).parameters...) x_nothunk = let # x_thunk = @thunk begin # this doesn't infer - dx = fill!(similar(x, T, axes(x)), false) + dx = fill!(similar(x, T, axes(x)), false) # This won't be twice-differentiable view(dx, ilo) .= first.(dy) view(dx, ihi) .= view(dx, ihi) .+ last.(dy) project(dx) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 6445eac3f..a456a8510 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -221,6 +221,10 @@ end @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") + + # Second derivatives + test_rrule(ChainRules._writezero, [1 2; 3 4], 5, CartesianIndex(2, 2), :) + test_rrule(ChainRules._writezero, [1 2; 3 4], 5, [CartesianIndex(2, 1) CartesianIndex(2, 2)], 1) end @testset "$imum" for imum in [maximum, minimum] From 4039c93c6921b6230453a1b3a89c180a3c979123 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 2 Aug 2021 23:24:57 -0400 Subject: [PATCH 25/30] frule? --- src/rulesets/Base/array.jl | 6 ++++++ test/rulesets/Base/array.jl | 1 + 2 files changed, 7 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index ae7e53eb2..38a11d41c 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -381,6 +381,12 @@ function _writezero(x, dy, ind, dims) dx end +# Allow for second derivatives: + +function frule((_, _, dydot, _, _), ::typeof(_writezero), x, dy, ind, dims) + return _writezero(x, dy, ind, dims), _writezero(x, dydot, ind, dims) +end + function rrule(::typeof(_writezero), x, dy, ind, dims) z = _writezero(x, dy, ind, dims) _writezero_pullback(dz) = (NoTangent(), NoTangent(), sum(view(unthunk(dz), ind); dims=dims), NoTangent(), NoTangent()) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index a456a8510..cb44fe340 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -223,6 +223,7 @@ end @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") # Second derivatives + @test_broken test_frule(ChainRules._writezero, [1 2; 3 4], 5, CartesianIndex(2, 2), :) test_rrule(ChainRules._writezero, [1 2; 3 4], 5, CartesianIndex(2, 2), :) test_rrule(ChainRules._writezero, [1 2; 3 4], 5, [CartesianIndex(2, 1) CartesianIndex(2, 2)], 1) end From 8e07ada6b7ee81eb143c4599e11c2a2e5735d42b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 3 Aug 2021 16:33:28 -0400 Subject: [PATCH 26/30] update to use CRC 1.3 --- src/rulesets/Base/array.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 38a11d41c..aa6184de3 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -381,7 +381,7 @@ function _writezero(x, dy, ind, dims) dx end -# Allow for second derivatives: +# Allow for second derivatives, by writing rules for `_writezero`: function frule((_, _, dydot, _, _), ::typeof(_writezero), x, dy, ind, dims) return _writezero(x, dy, ind, dims), _writezero(x, dydot, ind, dims) @@ -393,10 +393,7 @@ function rrule(::typeof(_writezero), x, dy, ind, dims) return z, _writezero_pullback end -Base.view(z::AbstractZero, ind...) = z # TODO move to ChainRulesCore -Base.sum(z::AbstractZero; dims=:) = z # TODO move to ChainRulesCore - -# These rules for `maximum` pick the same subgradient as findmax: +# These rules for `maximum` pick the same subgradient as `findmax`: function frule((_, xdot), ::typeof(maximum), x; dims=:) y, ind = findmax(x; dims=dims) @@ -424,6 +421,8 @@ end ##### `extrema` ##### +# This won't be twice-differentiable, could do something similar to `_writezero` above. + function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) if dims isa Colon return _extrema_colon(x) @@ -473,7 +472,7 @@ function _extrema_dims(x, dims) T = Base.promote_op(+, eltype(dy).parameters...) x_nothunk = let # x_thunk = @thunk begin # this doesn't infer - dx = fill!(similar(x, T, axes(x)), false) # This won't be twice-differentiable + dx = fill!(similar(x, T, axes(x)), false) view(dx, ilo) .= first.(dy) view(dx, ihi) .= view(dx, ihi) .+ last.(dy) project(dx) From 1722ca53dab50dd2287abfd9ce3105e325e3c814 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 5 Aug 2021 13:33:17 -0400 Subject: [PATCH 27/30] better writezero? --- src/rulesets/Base/array.jl | 31 +++++++++++++++++-------------- test/rulesets/Base/array.jl | 9 ++++++--- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index aa6184de3..4ae4e2403 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -361,7 +361,7 @@ for findm in (:findmin, :findmax) # This pullback is a lot like the one for getindex. Ideally they would probably be combined? function $findm_pullback((dy, _)) # this accept e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - x_thunk = @thunk project(_writezero(x, dy, ind, dims)) + x_thunk = @thunk project(_zerolike_writeat(x, dy, dims, ind)) x_ithunk = InplaceableThunk(x_thunk) do dx view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0 dx @@ -370,27 +370,32 @@ for findm in (:findmin, :findmax) end return (y, ind), $findm_pullback end - end -function _writezero(x, dy, ind, dims) +# This is roughly `setindex!(zero(x), dy, inds...)` +function _zerolike_writeat(x, dy, dims, inds...) # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't # allow `eltype(dy)`, nor does it work for many structured matrices. - dx = fill!(similar(x, eltype(dy), axes(x)), false) - view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray + dx = fill!(similar(x, eltype(dy), axes(x)), false) # zero(eltype(dy))) + view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray dx end -# Allow for second derivatives, by writing rules for `_writezero`: +# Allow for second derivatives, by writing rules for `_zerolike_writeat`; +# these rules are the reason it takes a `dims` argument. -function frule((_, _, dydot, _, _), ::typeof(_writezero), x, dy, ind, dims) - return _writezero(x, dy, ind, dims), _writezero(x, dydot, ind, dims) +function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...) + return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dydot, dims, inds...) end -function rrule(::typeof(_writezero), x, dy, ind, dims) - z = _writezero(x, dy, ind, dims) - _writezero_pullback(dz) = (NoTangent(), NoTangent(), sum(view(unthunk(dz), ind); dims=dims), NoTangent(), NoTangent()) - return z, _writezero_pullback +function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...) + z = _zerolike_writeat(x, dy, dims, inds...) + function _zerolike_writeat_pullback(dz) + dx = sum(view(unthunk(dz), inds...); dims=dims) + nots = map(_ -> NoTangent(), inds) + return (NoTangent(), NoTangent(), dx, NoTangent(), nots...) + end + return z, _zerolike_writeat_pullback end # These rules for `maximum` pick the same subgradient as `findmax`: @@ -421,8 +426,6 @@ end ##### `extrema` ##### -# This won't be twice-differentiable, could do something similar to `_writezero` above. - function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:) if dims isa Colon return _extrema_colon(x) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index cb44fe340..ac41fe74e 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -223,9 +223,12 @@ end @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") # Second derivatives - @test_broken test_frule(ChainRules._writezero, [1 2; 3 4], 5, CartesianIndex(2, 2), :) - test_rrule(ChainRules._writezero, [1 2; 3 4], 5, CartesianIndex(2, 2), :) - test_rrule(ChainRules._writezero, [1 2; 3 4], 5, [CartesianIndex(2, 1) CartesianIndex(2, 2)], 1) + test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) + test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) + @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching +(::Float64, ::Matrix{Float64}) + y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)]) + @test y == [0 0; 5 5] + @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent()) end @testset "$imum" for imum in [maximum, minimum] From 656602aa6f0e179765e008d27a72b45956bbf9cc Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Nov 2021 10:13:36 -0500 Subject: [PATCH 28/30] fix tests --- test/rulesets/Base/array.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index ac41fe74e..6acddfe5e 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -205,27 +205,29 @@ end test_frule(findmax, rand(10)) @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4))) isa Tuple{Tuple{Float64, CartesianIndex}, Tangent} @test @inferred(frule((nothing, rand(3,4)), findmin, rand(3,4), dims=1)) isa Tuple{Tuple{Matrix, Matrix}, Tangent} - @test_skip test_frule(findmin, rand(3,4)) # StackOverflowError, CartesianIndex{2}(index::Tuple{Float64, Float64}) (repeats 79984 times) & TypeError: in new, expected Tuple{Int64, Int64}, got a value of type Tuple{Float64, Float64} + @test_skip test_frule(findmin, rand(3,4)) # error from test_approx(actual::CartesianIndex{2}, expected::CartesianIndex{2} @test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent())) @test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,)) + # These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 # Reverse test_rrule(findmin, rand(10), output_tangent = (rand(), false)) test_rrule(findmax, rand(10), output_tangent = (rand(), false)) + test_rrule(findmin, rand(5,3)) + test_rrule(findmax, rand(5,3)) @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2]) - @test_skip test_rrule(findmin, rand(10), output_tangent = (rand(), NoTangent()), check_inferred=false) # DimensionMismatch from FiniteDifferences - @test_skip test_rrule(findmax, rand(5,3), output_tangent = (rand(), false), check_inferred=false) # DimensionMismatch from FiniteDifferences - + @test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2]) + # Reverse with dims: @test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2]) @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) - @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()), check_inferred=false) # DimensionMismatch("second dimension of A, 12, does not match length of x, 5"), wtf? - @test_skip test_rrule(findmin, rand(3,4), fkwargs=(dims=2,), output_tangent = (rand(3,1), falses(3,1)), check_inferred=false) # DimensionMismatch("second dimension of A, 9, does not match length of x, 7") + test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent())) + test_rrule(findmin, rand(3,4), fkwargs=(dims=2,)) # Second derivatives test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2)) - @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching +(::Float64, ::Matrix{Float64}) + @test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] ⊢ NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9) y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)]) @test y == [0 0; 5 5] @test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent()) @@ -249,7 +251,7 @@ end # Structured matrix -- NB the minimum is a structral zero here @test unthunk(rrule(imum, Diagonal(rand(3) .+ 1))[2](5.5)[2]) isa Diagonal - @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular + @test unthunk(rrule(imum, UpperTriangular(rand(3,3) .+ 1))[2](5.5)[2]) isa UpperTriangular{Float64} @test_skip test_rrule(imum, Diagonal(rand(3) .+ 1)) # MethodError: no method matching zero(::Type{Any}), from fill!(A::SparseArrays.SparseMatrixCSC{Any, Int64}, x::Bool) end From 77d7526d5ef4ec74aaab997fd1c19fb64c3d327a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Nov 2021 10:31:45 -0500 Subject: [PATCH 29/30] allow arrays of arrays --- src/rulesets/Base/array.jl | 33 ++++++++++++++++++++++++--------- test/rulesets/Base/array.jl | 6 ++++++ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 4ae4e2403..63ce70ae0 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -355,15 +355,19 @@ for findm in (:findmin, :findmax) return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent()) end - @eval function rrule(::typeof($findm), x::AbstractArray{<:Number}; dims=:) + @eval function rrule(::typeof($findm), x::AbstractArray; dims=:) y, ind = $findm(x; dims=dims) project = ProjectTo(x) # This pullback is a lot like the one for getindex. Ideally they would probably be combined? - function $findm_pullback((dy, _)) # this accept e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) + function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing) dy isa AbstractZero && return (NoTangent(), NoTangent()) - x_thunk = @thunk project(_zerolike_writeat(x, dy, dims, ind)) + x_thunk = @thunk project(_zerolike_writeat(x, unthunk(dy), dims, ind)) x_ithunk = InplaceableThunk(x_thunk) do dx - view(dx, ind) .= view(dx, ind) .+ dy # this could be .+=, but not on Julia 1.0 + if dims isa Colon + view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy)) + else + view(dx, ind) .= view(dx, ind) .+ unthunk(dy) # this could be .+=, but not on Julia 1.0 + end dx end return (NoTangent(), x_ithunk) @@ -372,14 +376,25 @@ for findm in (:findmin, :findmax) end end -# This is roughly `setindex!(zero(x), dy, inds...)` -function _zerolike_writeat(x, dy, dims, inds...) +# This function is roughly `setindex!(zero(x), dy, inds...)`: + +function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...) # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't # allow `eltype(dy)`, nor does it work for many structured matrices. - dx = fill!(similar(x, eltype(dy), axes(x)), false) # zero(eltype(dy))) + dx = fill!(similar(x, eltype(dy), axes(x)), 0) view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray dx end +function _zerolike_writeat(x::AbstractArray, dy, dims, inds...) + # Since we have `x`, we can also handle arrays of arrays. + dx = map(zero, x) + if dims isa Colon + view(dx, inds...) .= Ref(dy) + else + view(dx, inds...) .= dy + end + dx +end # Allow for second derivatives, by writing rules for `_zerolike_writeat`; # these rules are the reason it takes a `dims` argument. @@ -405,7 +420,7 @@ function frule((_, xdot), ::typeof(maximum), x; dims=:) return y, xdot[ind] end -function rrule(::typeof(maximum), x::AbstractArray{<:Number}; dims=:) +function rrule(::typeof(maximum), x::AbstractArray; dims=:) (y, _), back = rrule(findmax, x; dims=dims) maximum_pullback(dy) = back((dy, nothing)) return y, maximum_pullback @@ -416,7 +431,7 @@ function frule((_, xdot), ::typeof(minimum), x; dims=:) return y, xdot[ind] end -function rrule(::typeof(minimum), x::AbstractArray{<:Number}; dims=:) +function rrule(::typeof(minimum), x::AbstractArray; dims=:) (y, _), back = rrule(findmin, x; dims=dims) minimum_pullback(dy) = back((dy, nothing)) return y, minimum_pullback diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 6acddfe5e..ff3e93931 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -238,6 +238,8 @@ end test_frule(imum, rand(10)) test_frule(imum, rand(3,4)) test_frule(imum, rand(3,4), fkwargs=(dims=1,)) + test_frule(imum, [rand(2) for _ in 1:3]) + test_frule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(dims=1,)) # Reverse test_rrule(imum, rand(10)) @@ -245,6 +247,10 @@ end test_rrule(imum, rand(3,4), fkwargs=(dims=1,)) test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),)) + # Arrays of arrays + test_rrule(imum, [rand(2) for _ in 1:3]; check_inferred=false) + test_rrule(imum, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(dims=1,), check_inferred=false) + # Case which attains max twice -- can't use FiniteDifferences for this res = imum == maximum ? [0,1,0,0,0,0] : [1,0,0,0,0,0] @test res == @inferred unthunk(rrule(imum, [1,2,1,2,1,2])[2](1.0)[2]) From b17d91e1442ce66347267e1474824c4d2a30b9c8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Nov 2021 10:40:00 -0500 Subject: [PATCH 30/30] version --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 1274201c2..b36f894cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.13.0" +version = "1.14.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,10 +11,10 @@ RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "1.10" +ChainRulesCore = "1.11" ChainRulesTestUtils = "1" Compat = "3.35" -FiniteDifferences = "0.12.8" +FiniteDifferences = "0.12.20" JuliaInterpreter = "0.8" RealDot = "0.1" StaticArrays = "1.2"