From 366d98319220e811fd857321d02cd314b06601f0 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 27 Aug 2019 19:17:21 +0100 Subject: [PATCH 01/38] [WIP] include dervative WRT self. Scalar functions changed over --- Project.toml | 4 +-- test/rulesets/Base/base.jl | 57 +++++++++++++++++++++++--------------- test/test_util.jl | 15 +++++++++- 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/Project.toml b/Project.toml index acc564c2d..84f3ba69f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.1.1" +version = "0.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "^0.2" +ChainRulesCore = "^0.3" FiniteDifferences = "^0.7" julia = "^1.0" diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a4d0d4bd8..a02617444 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -52,25 +52,37 @@ test_scalar(acotd, 1/x) end @testset "Multivariate" begin - x, y = rand(2) - ratan = atan(x, y) # https://en.wikipedia.org/wiki/Atan2 - u = x^2 + y^2 - datan = y/u - 2x/u - r, df = frule(atan, x, y) - @test r === ratan - @test df(1, 2) === datan - r, (df1, df2) = rrule(atan, x, y) - @test r === ratan - @test df1(1) + df2(2) === datan - - rsincos = sincos(x) - dsincos = cos(x) - 2sin(x) - r, (df1, df2) = frule(sincos, x) - @test r === rsincos - @test df1(1) + df2(2) === dsincos - r, df = rrule(sincos, x) - @test r === rsincos - @test df(1, 2) === dsincos + x, y = rand(2)i + @testset "atan2" begin + ratan = atan(x, y) # https://en.wikipedia.org/wiki/Atan2 + u = x^2 + y^2 + datan = y/u - 2x/u + + r, (ds, df) = frule(atan, x, y) + @test r === ratan + @test df(1, 2) === datan + @test ds === NO_FIELDS_RULE + + r, (ds, df1, df2) = rrule(atan, x, y) + @test r === ratan + @test ds === 
NO_FIELDS_RULE + @test df1(1) + df2(2) === datan + end + + @testset "sincos" begin + rsincos = sincos(x) + dsincos = cos(x) - 2sin(x) + + r, (ds, df1, df2) = frule(sincos, x) + @test r === rsincos + @test df1(1) + df2(2) === dsincos + @test ds === NO_FIELDS_RULE + + r, (ds, df) = rrule(sincos, x) + @test r === rsincos + @test df(1, 2) === dsincos + @test ds === NO_FIELDS_RULE + end end end # Trig @@ -116,12 +128,12 @@ @testset "*(x, y)" begin x, y = rand(3, 2), rand(2, 5) - z, (dx, dy) = rrule(*, x, y) + z, (ds, dx, dy) = rrule(*, x, y) @test z == x * y z̄ = rand(3, 5) - + @test ds === NO_FIELDS_RULE @test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄)) @test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄)) @@ -131,8 +143,9 @@ @testset "hypot(x, y)" begin x, y = rand(2) - h, dxy = frule(hypot, x, y) + h, (ds, dxy) = frule(hypot, x, y) + @test ds === NO_FIELDS_RULE @test extern(dxy(One(), Zero())) === x / h @test extern(dxy(Zero(), One())) === y / h diff --git a/test/test_util.jl b/test/test_util.jl index 571b9c36f..cc547ce21 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -19,13 +19,26 @@ at input point `x` to confirm that there are correct ChainRules provided. All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`. """ +<<<<<<< HEAD function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...) +======= +function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) + if fieldcount(typeof(f)) > 0 + throw(ArgumentError( + "test_scalar cannot be used on closures/functors (such as $f)" + )) + end + +>>>>>>> [WIP] include dervative WRT self. 
@testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) res = rule(f, x) @test res !== nothing # Check the rule was defined - fx, ∂x = res + fx, ∂s = res @test fx == f(x) # Check we still get the normal value, right + ∂self, ∂x = ∂s + @test ∂self === NamedTuple() # No internal fields + # Check that we get the derivative right: if !test_wirtinger @test isapprox( From 4e9e3302be4744446cd13b2b6e258d543f82c9ea Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 28 Aug 2019 12:57:53 +0100 Subject: [PATCH 02/38] wip --- test/rulesets/Base/base.jl | 2 +- test/test_util.jl | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a02617444..5ae828411 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -52,7 +52,7 @@ test_scalar(acotd, 1/x) end @testset "Multivariate" begin - x, y = rand(2)i + x, y = rand(2) @testset "atan2" begin ratan = atan(x, y) # https://en.wikipedia.org/wiki/Atan2 u = x^2 + y^2 diff --git a/test/test_util.jl b/test/test_util.jl index cc547ce21..9672d6669 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -19,25 +19,20 @@ at input point `x` to confirm that there are correct ChainRules provided. All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`. """ -<<<<<<< HEAD function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...) -======= -function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) if fieldcount(typeof(f)) > 0 throw(ArgumentError( "test_scalar cannot be used on closures/functors (such as $f)" )) end ->>>>>>> [WIP] include dervative WRT self. 
@testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) res = rule(f, x) @test res !== nothing # Check the rule was defined - fx, ∂s = res + fx, (∂self_rule, ∂x_rule) = res @test fx == f(x) # Check we still get the normal value, right - ∂self, ∂x = ∂s - @test ∂self === NamedTuple() # No internal fields + @test ∂self_rule === NO_FIELDS_RULE # No internal fields # Check that we get the derivative right: if !test_wirtinger @@ -64,6 +59,7 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) end + """ frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) @@ -80,13 +76,18 @@ end function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) xs, ẋs = collect(zip(xẋs...)) - Ω, dΩ_rule = ChainRules.frule(f, xs...) + Ω, (∂self_rule, dΩ_rule) = ChainRules.frule(f, xs...) @test f(xs...) == Ω - dΩ_ad, dΩ_fd = dΩ_rule(ẋs...), jvp(fdm, xs->f(xs...), (xs, ẋs)) + @test ∂self_rule === NO_FIELDS_RULE # No internal fields + + dΩ_ad = dΩ_rule(ẋs...) + dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) @test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...) end +fooo⃖ +foo⃡ """ rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) 
From 4f641d6d120215910326cb22ec65bb653d9bac94 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 28 Aug 2019 16:16:51 +0100 Subject: [PATCH 03/38] WIP --- src/rulesets/Base/base.jl | 10 ++++------ test/rulesets/Base/base.jl | 7 +++---- test/test_util.jl | 21 +++++++++++++-------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index b6f67f5d6..c679087d3 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -103,10 +103,8 @@ # product rule requires special care for arguments where `mul` is non-commutative -frule(::typeof(*), x::Number, y::Number) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy) +frule(::typeof(*), x::Number, y::Number) = x * y, (ZERO_RULE, Rule((Δx, Δy) -> Δx * y + x * Δy)) +rrule(::typeof(*), x::Number, y::Number) = x * y, (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) -rrule(::typeof(*), x::Number, y::Number) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) - -frule(::typeof(identity), x) = x, Rule(identity) - -rrule(::typeof(identity), x) = x, Rule(identity) +frule(::typeof(identity), x) = x, (ZERO_RULE, Rule(identity)) +rrule(::typeof(identity), x) = x, (NO_FIELDS_RULE, Rule(identity)) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 5ae828411..237ab33e2 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -61,7 +61,7 @@ r, (ds, df) = frule(atan, x, y) @test r === ratan @test df(1, 2) === datan - @test ds === NO_FIELDS_RULE + @test ds === ZERO_RULE r, (ds, df1, df2) = rrule(atan, x, y) @test r === ratan @@ -76,7 +76,7 @@ r, (ds, df1, df2) = frule(sincos, x) @test r === rsincos @test df1(1) + df2(2) === dsincos - @test ds === NO_FIELDS_RULE + @test ds === ZERO_RULE r, (ds, df) = rrule(sincos, x) @test r === rsincos @@ -145,7 +145,7 @@ x, y = rand(2) h, (ds, dxy) = frule(hypot, x, y) - @test ds === NO_FIELDS_RULE + @test ds === ZERO_RULE @test extern(dxy(One(), Zero())) === x / h @test 
extern(dxy(Zero(), One())) === y / h @@ -162,7 +162,6 @@ @testset "identity" begin rng = MersenneTwister(1) - n = 4 rrule_test(identity, randn(rng), (randn(rng), randn(rng))) rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4))) end diff --git a/test/test_util.jl b/test/test_util.jl index 9672d6669..d4bad598f 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -32,7 +32,9 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa fx, (∂self_rule, ∂x_rule) = res @test fx == f(x) # Check we still get the normal value, right - @test ∂self_rule === NO_FIELDS_RULE # No internal fields + # No internal fields + rule===rrule && @test ∂self_rule === NO_FIELDS_RULE + rule===frule && @test ∂self_rule === ZERO_RULE # Check that we get the derivative right: if !test_wirtinger @@ -79,15 +81,15 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm Ω, (∂self_rule, dΩ_rule) = ChainRules.frule(f, xs...) @test f(xs...) == Ω - @test ∂self_rule === NO_FIELDS_RULE # No internal fields + @test ∂self_rule === ZERO_RULE # No internal fields + # Correctness testing via finite differencing. dΩ_ad = dΩ_rule(ẋs...) dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) @test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...) end -fooo⃖ -foo⃡ + """ rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) @@ -101,16 +103,19 @@ All keyword arguments except for `fdm` are passed to `isapprox`. """ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) # Check correctness of evaluation. - fx, dx = ChainRules.rrule(f, x) + fx, (∂self_rule, dx_rule) = ChainRules.rrule(f, x) @test fx ≈ f(x) + @test ∂self_rule === NO_FIELDS_RULE # No internal fields + # Correctness testing via finite differencing. - x̄_ad, x̄_fd = dx(ȳ), j′vp(fdm, f, ȳ, x) + x̄_ad = dx_rule(ȳ) + x̄_fd = j′vp(fdm, f, ȳ, x) @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) 
# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct. - test_accumulation(x̄, dx, ȳ, x̄_ad) - test_accumulation(Zero(), dx, ȳ, x̄_ad) + test_accumulation(x̄, dx_rule, ȳ, x̄_ad) + test_accumulation(Zero(), dx_rule, ȳ, x̄_ad) end function _make_fdm_call(fdm, f, ȳ, xs, ignores) From c5758954d7fbe4358d7be94a43569b656973d6fb Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 28 Aug 2019 18:30:40 +0100 Subject: [PATCH 04/38] [WIP] make changes to all the rules to return WRT self Temp comment out all accumulate related tests --- src/rulesets/Base/array.jl | 16 +++--- src/rulesets/Base/broadcast.jl | 4 +- src/rulesets/Base/mapreduce.jl | 17 ++++--- src/rulesets/LinearAlgebra/blas.jl | 26 +++++----- src/rulesets/LinearAlgebra/dense.jl | 35 +++++++------ src/rulesets/LinearAlgebra/factorization.jl | 10 ++-- src/rulesets/LinearAlgebra/structured.jl | 26 +++++----- test/test_util.jl | 55 ++++++++++++--------- 8 files changed, 103 insertions(+), 86 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b071f0447..b66209fc1 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -3,12 +3,14 @@ ##### function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) - return reshape(A, dims), (Rule(Ȳ->reshape(Ȳ, dims)), DNERule()) + return reshape(A, dims), (NO_FIELDS_RULE, Rule(Ȳ->reshape(Ȳ, dims)), DNERule()) end function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) - Y, (rule, _) = rrule(reshape, A, dims) - return Y, (rule, fill(DNERule(), length(dims))...) + Y, (nofields, rule, dne) = rrule(reshape, A, dims)[2] + @assert no_fields === NO_FIELDS_RULE + @assert dne === DNERule() + return Y, (NO_FIELDS_RULE, rule, fill(DNERule(), length(dims))...) end ##### @@ -26,7 +28,7 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) # materialize with `copy` Rule(Ȳ->copy(selectdim(Ȳ, 2, dim))) end - return Y, rules + return Y, (NO_FIELDS_RULE, rules...) 
end ##### @@ -42,7 +44,7 @@ function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) u = l + size(Bs[i], 1) Rule(Ȳ->copy(selectdim(Ȳ, 1, l+1:u))) end - return Y, (∂A, ∂Bs...) + return Y, (NO_FIELDS_RULE, ∂A, ∂Bs...) end ##### @@ -50,9 +52,9 @@ end ##### function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) - return fill(value, dims), (Rule(sum), DNERule()) + return fill(value, dims), (NO_FIELDS_RULE, Rule(sum), DNERule()) end function rrule(::typeof(fill), value::Any, dims::Int...) - return fill(value, dims), (Rule(sum), ntuple(_->DNERule(), length(dims))...) + return fill(value, dims), (NO_FIELDS_RULE, Rule(sum), ntuple(_->DNERule(), length(dims))...) end diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index f3685f5ca..c57ad8c59 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -15,10 +15,10 @@ end function frule(::typeof(broadcast), f, x) Ω, ∂x = _cast_diff(f, x) - return Ω, Rule((_, Δx) -> Δx * cast(∂x)) + return Ω, (ZERO_RULE, Rule((_, Δx) -> Δx * cast(∂x))) end function rrule(::typeof(broadcast), f, x) values, derivs = _cast_diff(f, x) - return values, (DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs))) + return values, (NO_FIELDS_RULE, DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs))) end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index b0e6a006b..7a01823b0 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -12,7 +12,7 @@ function rrule(::typeof(map), f, xs...) end end end - return y, (DNERule(), ∂xs...) + return y, (NO_FIELDS_RULE, DNERule(), ∂xs...) 
end ##### @@ -34,7 +34,7 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr) extern(∂xi(ȳi)) end end - return y, (DNERule(), DNERule(), ∂x) + return y, (NO_FIELDS_RULE, DNERule(), DNERule(), ∂x) end eval(Expr(:function, sig, body)) end @@ -43,22 +43,23 @@ end ##### `sum` ##### -frule(::typeof(sum), x) = (sum(x), Rule(sum)) +frule(::typeof(sum), x) = (sum(x), (ZERO_RULE, Rule(sum))) -rrule(::typeof(sum), x) = (sum(x), Rule(cast)) +rrule(::typeof(sum), x) = (sum(x), (NO_FIELDS_RULE, Rule(cast))) function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims) - return y, (DNERule(), ∂x) + return y, (NO_FIELDS_RULE, DNERule(), ∂x) end function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:) - y, (_, ∂x) = rrule(sum, identity, x; dims=dims) - return y, ∂x + y, (no_fields, _, ∂x) = rrule(sum, identity, x; dims=dims) + @assert(no_fields === NO_FIELDS_RULE) + return y, (NO_FIELDS_RULE, ∂x) end function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) y = sum(abs2, x; dims=dims) ∂x = Rule(ȳ -> 2ȳ .* x) - return y, (DNERule(), ∂x) + return y, (NO_FIELDS_RULE, DNERule(), ∂x) end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index fb5b23f4b..9f873aedd 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -21,7 +21,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) Ω = BLAS.dot(n, X, incx, Y, incy) ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) - return Ω, (DNERule(), _rule_via(∂X), DNERule(), _rule_via(∂Y), DNERule()) + return Ω, (NO_FIELDS_RULE, DNERule(), _rule_via(∂X), DNERule(), _rule_via(∂Y), DNERule()) end ##### @@ -30,32 +30,36 @@ end function frule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) - return Ω, Rule(Δx -> sum(Δx * cast(@thunk(x * inv(Ω))))) + return Ω, (ZERO_RULE, Rule(Δx -> sum(Δx * 
cast(@thunk(x * inv(Ω)))))) end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) - return Ω, Rule(ΔΩ -> ΔΩ * @thunk(x * inv(Ω))) + return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * @thunk(x * inv(Ω)))) end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) ∂X = ΔΩ -> scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) - return Ω, (DNERule(), _rule_via(∂X), DNERule()) + return Ω, (NO_FIELDS_RULE, DNERule(), _rule_via(∂X), DNERule()) end ##### ##### `BLAS.asum` ##### -frule(::typeof(BLAS.asum), x) = (BLAS.asum(x), Rule(Δx -> sum(cast(sign, x) * Δx))) +function frule(::typeof(BLAS.asum), x) + return BLAS.asum(x), (ZERO_RULE, Rule(Δx -> sum(cast(sign, x) * Δx))) +end -rrule(::typeof(BLAS.asum), x) = (BLAS.asum(x), Rule(ΔΩ -> ΔΩ * cast(sign, x))) +function rrule(::typeof(BLAS.asum), x) + return BLAS.asum(x), (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * cast(sign, x))) +end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, sign.(X), incx, _zeros(X), incx), incx) - return Ω, (DNERule(), _rule_via(∂X), DNERule()) + return Ω, (NO_FIELDS_RULE, DNERule(), _rule_via(∂X), DNERule()) end ##### @@ -72,13 +76,13 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, ∂A = Rule(ȳ -> α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā)) ∂x = Rule(ȳ -> gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄)) end - return y, (DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x) + return y, (NO_FIELDS_RULE, DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x) end function rrule(::typeof(gemv), tA::Char, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat y, (dtA, _, dA, dx) = rrule(gemv, tA, one(T), A, x) - return y, (dtA, dA, dx) + return y, (NO_FIELDS_RULE, dtA, dA, dx) end ##### @@ -114,11 +118,11 @@ function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, (B̄, C̄) -> gemm!('T', 'T', α, C̄, A, β, B̄)) end end - return C, (DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B) + return C, 
(NO_FIELDS_RULE, DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B) end function rrule(::typeof(gemm), tA::Char, tB::Char, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat C, (dtA, dtB, _, dA, dB) = rrule(gemm, tA, tB, one(T), A, B) - return C, (dtA, dtB, dA, dB) + return C, (NO_FIELDS_RULE, dtA, dtB, dA, dB) end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 9eb3ee168..010a7dce1 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -9,11 +9,11 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} ##### function frule(::typeof(dot), x, y) - return dot(x, y), Rule((Δx, Δy) -> sum(Δx * cast(y)) + sum(cast(x) * Δy)) + return dot(x, y), (ZERO_RULE, Rule((Δx, Δy) -> sum(Δx * cast(y)) + sum(cast(x) * Δy))) end function rrule(::typeof(dot), x, y) - return dot(x, y), (Rule(ΔΩ -> ΔΩ * cast(y)), Rule(ΔΩ -> cast(x) * ΔΩ)) + return dot(x, y), (NO_FIELDS_RULE, (Rule(ΔΩ -> ΔΩ * cast(y)), Rule(ΔΩ -> cast(x) * ΔΩ))) end ##### @@ -23,13 +23,13 @@ end function frule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω) - return Ω, Rule(Δx -> m * Δx * Ω) + return Ω, (ZERO_RULE, Rule(Δx -> m * Δx * Ω)) end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω') - return Ω, Rule(ΔΩ -> m * ΔΩ * Ω') + return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> m * ΔΩ * Ω')) end ##### @@ -38,12 +38,12 @@ end function frule(::typeof(det), x) Ω, m = det(x), @thunk(inv(x)) - return Ω, Rule(Δx -> Ω * tr(extern(m * Δx))) + return Ω, (ZERO_RULE, Rule(Δx -> Ω * tr(extern(m * Δx)))) end function rrule(::typeof(det), x) Ω, m = det(x), @thunk(inv(x)') - return Ω, Rule(ΔΩ -> Ω * ΔΩ * m) + return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> Ω * ΔΩ * m)) end ##### @@ -52,28 +52,27 @@ end function frule(::typeof(logdet), x) Ω, m = logdet(x), @thunk(inv(x)) - return Ω, Rule(Δx -> tr(extern(m * Δx))) + return Ω, (ZERO_RULE, Rule(Δx -> tr(extern(m * Δx)))) end function rrule(::typeof(logdet), x) Ω, 
m = logdet(x), @thunk(inv(x)') - return Ω, Rule(ΔΩ -> ΔΩ * m) + return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * m)) end ##### ##### `trace` ##### -frule(::typeof(tr), x) = (tr(x), Rule(Δx -> tr(extern(Δx)))) - -rrule(::typeof(tr), x) = (tr(x), Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1))))) +frule(::typeof(tr), x) = (tr(x), (ZERO_RULE, Rule(Δx -> tr(extern(Δx))))) +rrule(::typeof(tr), x) = (tr(x), (NO_FIELDS_RULE, Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1)))))) ##### ##### `*` ##### function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) - return A * B, (Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ)) + return A * B, (NO_FIELDS_RULE, Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ)) end ##### @@ -85,7 +84,7 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr S = T.name.wrapper ∂A = Rule(Ȳ -> Ȳ / B') ∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B'))) - return Y, (∂A, ∂B) + return Y, (NO_FIELDS_RULE, ∂A, ∂B) end function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) @@ -95,7 +94,7 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R C, dC = rrule(adjoint, Cᵀ) ∂A = Rule(dA∘dAᵀ∘dC) ∂B = Rule(dA∘dBᵀ∘dC) - return C, (∂A, ∂B) + return C, (NO_FIELDS_RULE, ∂A, ∂B) end ##### @@ -107,7 +106,7 @@ function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMa S = T.name.wrapper ∂A = Rule(Ȳ -> S(-(A' \ Ȳ) * Y')) ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (∂A, ∂B) + return Y, (NO_FIELDS_RULE, ∂A, ∂B) end function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) @@ -120,7 +119,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R Ā end ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (∂A, ∂B) + return Y, (NO_FIELDS_RULE, ∂A, ∂B) end ##### @@ -132,9 +131,9 @@ function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2) u = y^(1-p) ∂A = Rule(ȳ -> ȳ .* u .* abs.(A).^p ./ A) ∂p = Rule(ȳ -> ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p) - 
return y, (∂A, ∂p) + return y, (NO_FIELDS_RULE, ∂A, ∂p) end function rrule(::typeof(norm), x::Real, p::Real=2) - return norm(x, p), (Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x))) + return norm(x, p), (NO_FIELDS_RULE, Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x))) end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 72527fcc6..6483711b0 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -10,7 +10,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) ∂X = Rule() do Ȳ::NamedTuple{(:U,:S,:V)} svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) end - return F, ∂X + return F, (NO_FIELDS_RULE, ∂X) end function rrule(::typeof(getproperty), F::SVD, x::Symbol) @@ -25,7 +25,7 @@ function rrule(::typeof(getproperty), F::SVD, x::Symbol) throw(ArgumentError("Vt is unsupported; use V and transpose the result")) end update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x) - return getproperty(F, x), (Rule(rule, update), DNERule()) + return getproperty(F, x), (NO_FIELDS_RULE, Rule(rule, update), DNERule()) end function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix) @@ -66,7 +66,7 @@ end function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) F = cholesky(X) ∂X = Rule(Ȳ->chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) - return F, ∂X + return F, (NO_FIELDS_RULE, ∂X) end function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) @@ -83,7 +83,7 @@ function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) ∂F = Ȳ->UpperTriangular(Ȳ') end end - return getproperty(F, x), (Rule(∂F), DNERule()) + return getproperty(F, x), (NO_FIELDS_RULE, Rule(∂F), DNERule()) end # See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, @@ -184,7 +184,7 @@ end """ chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) -Compute the sensitivities of the Cholesky factorization using a blocked, 
cache-friendly +Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index d2ee20309..9e98b9d99 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -4,15 +4,15 @@ ##### `Diagonal` ##### -rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag) +rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), (NO_FIELDS_RULE, Rule(diag)) -rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal) +rrule(::typeof(diag), A::AbstractMatrix) = diag(A), (NO_FIELDS_RULE, Rule(Diagonal)) ##### ##### `Symmetric` ##### -rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) +rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), (NO_FIELDS_RULE, Rule(_symmetric_back)) _symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ @@ -22,26 +22,26 @@ _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### # TODO: Deal with complex-valued arrays as well -rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), Rule(adjoint) -rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), Rule(vec∘adjoint) +rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), (NO_FIELDS_RULE, Rule(adjoint)) +rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), (NO_FIELDS_RULE, Rule(vec∘adjoint)) -rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), Rule(adjoint) -rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), Rule(vec∘adjoint) 
+rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), (NO_FIELDS_RULE, Rule(adjoint)) +rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), (NO_FIELDS_RULE, Rule(vec∘adjoint)) ##### ##### `Transpose` ##### -rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), Rule(transpose) -rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), Rule(vec∘transpose) +rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), (NO_FIELDS_RULE, Rule(transpose)) +rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), (NO_FIELDS_RULE, Rule(vec∘transpose)) -rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), Rule(transpose) -rrule(::typeof(transpose), A::AbstractVector) = transpose(A), Rule(vec∘transpose) +rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), (NO_FIELDS_RULE, Rule(transpose)) +rrule(::typeof(transpose), A::AbstractVector) = transpose(A), (NO_FIELDS_RULE, Rule(vec∘transpose)) ##### ##### Triangular matrices ##### -rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix) +rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), (NO_FIELDS_RULE, Rule(Matrix)) -rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix) +rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), (NO_FIELDS_RULE, Rule(Matrix)) diff --git a/test/test_util.jl b/test/test_util.jl index d4bad598f..ff8fd2267 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -20,11 +20,7 @@ at input point `x` to confirm that there are correct ChainRules provided. All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`. """ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...) 
- if fieldcount(typeof(f)) > 0 - throw(ArgumentError( - "test_scalar cannot be used on closures/functors (such as $f)" - )) - end + ensure_not_running_on_functor(f, "test_scalar") @testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) res = rule(f, x) @@ -39,7 +35,7 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa # Check that we get the derivative right: if !test_wirtinger @test isapprox( - ∂x(1), fdm(f, x); + ∂x_rule(1), fdm(f, x); rtol=rtol, atol=atol, kwargs... ) else @@ -49,18 +45,24 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa ∂ = 0.5(∂Re - im*∂Im) ∂̅ = 0.5(∂Re + im*∂Im) @test isapprox( - wirtinger_primal(∂x(1)), ∂; + wirtinger_primal(∂x_rule(1)), ∂; rtol=rtol, atol=atol, kwargs... ) @test isapprox( - wirtinger_conjugate(∂x(1)), ∂̅; + wirtinger_conjugate(∂x_rule(1)), ∂̅; rtol=rtol, atol=atol, kwargs... ) end end end - +function ensure_not_running_on_functor(f, name) + if fieldcount(typeof(f)) > 0 + throw(ArgumentError( + "$name cannot be used on closures/functors (such as $f)" + )) + end +end """ frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) @@ -77,6 +79,7 @@ function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) end function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) + ensure_not_running_on_functor(f, "frule_test") xs, ẋs = collect(zip(xẋs...)) Ω, (∂self_rule, dΩ_rule) = ChainRules.frule(f, xs...) @test f(xs...) == Ω @@ -102,6 +105,8 @@ end All keyword arguments except for `fdm` are passed to `isapprox`. """ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) + ensure_not_running_on_functor(f, "rrule_test") + # Check correctness of evaluation. 
fx, (∂self_rule, dx_rule) = ChainRules.rrule(f, x) @test fx ≈ f(x) @@ -146,13 +151,19 @@ function _make_fdm_call(fdm, f, ȳ, xs, ignores) end function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) + ensure_not_running_on_functor(f, "rrule_test") + # Check correctness of evaluation. xs, x̄s = collect(zip(xx̄s...)) y, rules = rrule(f, xs...) @test f(xs...) == y + self_rule = rules[1] + arg_rules = rules[2:end] + @test self_rule === NO_FIELDS_RULE + # Correctness testing via finite differencing. - x̄s_ad = map(rules) do rule + x̄s_ad = map(arg_rules) do rule rule isa DNERule ? DNE() : rule(ȳ) end x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing) @@ -168,8 +179,8 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm # Assuming the above to be correct, check that other ChainRules mechanisms are correct. for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad) x̄ === nothing && continue - test_accumulation(x̄, rule, ȳ, x̄_ad) - test_accumulation(Zero(), rule, ȳ, x̄_ad) + #test_accumulation(x̄, rule, ȳ, x̄_ad) + #test_accumulation(Zero(), rule, ȳ, x̄_ad) end end @@ -187,7 +198,7 @@ function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) 
end function test_accumulation(x̄, dx, ȳ, partial) - @test all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) + #@test all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) test_accumulate(x̄, dx, ȳ, partial) test_accumulate!(x̄, dx, ȳ, partial) test_store!(x̄, dx, ȳ, partial) @@ -195,33 +206,33 @@ function test_accumulation(x̄, dx, ȳ, partial) end function test_accumulate(x̄::Zero, dx, ȳ, partial) - @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) + #@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) return nothing end function test_accumulate(x̄::Number, dx, ȳ, partial) - @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) + #@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) return nothing end function test_accumulate(x̄::AbstractArray, dx, ȳ, partial) x̄_old = copy(x̄) - @test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) - @test x̄ == x̄_old + #@test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) + #@test x̄ == x̄_old return nothing end test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing function test_accumulate!(x̄::Number, dx, ȳ, partial) - @test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ) + #@test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ) return nothing end function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial) x̄_copy = copy(x̄) - accumulate!(x̄_copy, dx, ȳ) - @test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) + #accumulate!(x̄_copy, dx, ȳ) + #@test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) return nothing end @@ -230,7 +241,7 @@ test_store!(x̄::Number, dx, ȳ, partial) = nothing function test_store!(x̄::AbstractArray, dx, ȳ, partial) x̄_copy = copy(x̄) - store!(x̄_copy, dx, ȳ) - @test all(x̄_copy .≈ extern(partial)) + #store!(x̄_copy, dx, ȳ) + #@test all(x̄_copy .≈ extern(partial)) return nothing end From 0624600281743a568711c5d0106c804e0d85a35b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 29 Aug 2019 17:26:50 +0100 
Subject: [PATCH 05/38] comment out more accumulate --- test/rulesets/Base/base.jl | 3 +++ test/rulesets/Base/broadcast.jl | 6 ++++-- test/rulesets/LinearAlgebra/factorization.jl | 5 ++++- test/test_util.jl | 12 +++++++----- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 237ab33e2..70a905638 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -134,11 +134,14 @@ z̄ = rand(3, 5) @test ds === NO_FIELDS_RULE + + #== TODO: reanable me @test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄)) @test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄)) test_accumulation(rand(3, 2), dx, z̄, z̄ * y') test_accumulation(rand(2, 5), dy, z̄, x' * z̄) + ==# end @testset "hypot(x, y)" begin diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 74660db9e..45e85031c 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -8,13 +8,15 @@ @test extern(dx(One())) == cos.(x) x̄, ȳ = rand(), rand() - @test isequal( + + + @test_skip isequal( extern(ChainRules.accumulate(x̄, dx, ȳ)), x̄ .+ ȳ .* cos.(x) ) x̄, ȳ = Zero(), rand(3, 3) - @test extern(accumulate(x̄, dx, ȳ)) == ȳ .* cos.(x) + @test_skip extern(accumulate(x̄, dx, ȳ)) == ȳ .* cos.(x) end end end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 7713a59f5..5f88288f0 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -16,6 +16,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo end @test_throws ArgumentError rrule(getproperty, F, :Vt) end + #== TODO: re-enable me @testset "accumulate!" 
begin X = [1.0 2.0; 3.0 4.0; 5.0 6.0] F, dX = rrule(svd, X) @@ -29,6 +30,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @test X̄.S ≈ ones(2) atol=1e-6 @test X̄.V ≈ ones(2, 2) atol=1e-6 end + ==# @testset "Helper functions" begin X = randn(rng, 10, 10) Y = randn(rng, 10, 10) @@ -44,7 +46,8 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo V = generate_well_conditioned_matrix(rng, 10) F, dX = rrule(cholesky, X) for p in [:U, :L] - Y, (dF, dp) = rrule(getproperty, F, p) + Y, (dself, dF, dp) = rrule(getproperty, F, p) + @test dself === NO_FIELDS_RULE @test dp isa ChainRules.DNERule Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y))) # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` diff --git a/test/test_util.jl b/test/test_util.jl index ff8fd2267..24ff36bed 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -198,7 +198,7 @@ function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) 
end function test_accumulation(x̄, dx, ȳ, partial) - #@test all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) + @test_skip all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) test_accumulate(x̄, dx, ȳ, partial) test_accumulate!(x̄, dx, ȳ, partial) test_store!(x̄, dx, ȳ, partial) @@ -206,19 +206,19 @@ function test_accumulation(x̄, dx, ȳ, partial) end function test_accumulate(x̄::Zero, dx, ȳ, partial) - #@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) + @test_skip extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) return nothing end function test_accumulate(x̄::Number, dx, ȳ, partial) - #@test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) + @test_skip extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) return nothing end function test_accumulate(x̄::AbstractArray, dx, ȳ, partial) x̄_old = copy(x̄) - #@test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) - #@test x̄ == x̄_old + @test_skip all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) + @test x̄ == x̄_old return nothing end @@ -231,6 +231,7 @@ end function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial) x̄_copy = copy(x̄) + #TODO Reeable me #accumulate!(x̄_copy, dx, ȳ) #@test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) return nothing @@ -241,6 +242,7 @@ test_store!(x̄::Number, dx, ȳ, partial) = nothing function test_store!(x̄::AbstractArray, dx, ȳ, partial) x̄_copy = copy(x̄) + # TODO: renable me #store!(x̄_copy, dx, ȳ) #@test all(x̄_copy .≈ extern(partial)) return nothing From 6c7478d89c4941900a963230deb2f23e67867d81 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 30 Aug 2019 19:56:27 +0100 Subject: [PATCH 06/38] WIP: --- proto.jl | 52 ++++++++++++++++++ src/ChainRules.jl | 9 ++-- src/rulesets/Base/array.jl | 57 +++++++++++--------- src/rulesets/Base/base.jl | 4 +- src/rulesets/Base/broadcast.jl | 2 +- src/rulesets/Base/mapreduce.jl | 14 ++--- src/rulesets/LinearAlgebra/blas.jl | 18 +++---- 
src/rulesets/LinearAlgebra/dense.jl | 24 ++++----- src/rulesets/LinearAlgebra/factorization.jl | 8 +-- src/rulesets/LinearAlgebra/structured.jl | 26 ++++----- test/rulesets/Base/base.jl | 6 +-- test/rulesets/LinearAlgebra/factorization.jl | 2 +- test/test_util.jl | 6 +-- 13 files changed, 143 insertions(+), 85 deletions(-) create mode 100644 proto.jl diff --git a/proto.jl b/proto.jl new file mode 100644 index 000000000..52de5fc2d --- /dev/null +++ b/proto.jl @@ -0,0 +1,52 @@ +using Pkg: @pkg_str +pkg"activate /Users/oxinabox/JuliaEnvs/ChainRulesWorld/" +using Revise + + +include("/Users/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRulesCore.jl/test/runtests.jl") + +include("/Users/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRules.jl/test/runtests.jl") + + +using FiniteDifferences +using Test +using ChainRules +using Random + +const accumulate = ChainRules.ChainRulesCore.accumulate +const accumulate! = ChainRules.ChainRulesCore.accumulate! +const add = ChainRules.ChainRulesCore.add +includet("/Users/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRules.jl/test/test_util.jl") + +#== +Test Summary: | Pass Fail Error Total +ChainRules | 2271 38 88 2397 + +==# +using MacroTools +using MacroTools: textwalk + +code = """ +Rule(x -> -sin(x)) +Rule(x -> 1 + tan(x)^2) +"""; + +after = textwalk(code) do expr + @capture(expr, Rule(v_)) && return MacroTools.postwalk(MacroTools.unblock, v) + return expr +end + +println(after) + +############## + + +code = """ +y->-sin(x) +"""; + +after = textwalk(code) do expr + @capture(expr, v_) && return v +end + +println(after) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index d430c36bd..d8ab13d49 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -24,11 +24,12 @@ end include("helper_functions.jl") -include("rulesets/Base/base.jl") +#include("rulesets/Base/base.jl") include("rulesets/Base/array.jl") -include("rulesets/Base/broadcast.jl") -include("rulesets/Base/mapreduce.jl") +#include("rulesets/Base/broadcast.jl") 
+#include("rulesets/Base/mapreduce.jl") +#== include("rulesets/Statistics/statistics.jl") include("rulesets/LinearAlgebra/utils.jl") @@ -51,5 +52,5 @@ function __init__() using .SpecialFunctionsGlue end end - +==# end # module diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b66209fc1..e8e0b5309 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -3,14 +3,14 @@ ##### function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) - return reshape(A, dims), (NO_FIELDS_RULE, Rule(Ȳ->reshape(Ȳ, dims)), DNERule()) + return reshape(A, dims), Ȳ -> (NO_FIELDS, reshape(Ȳ, dims), DNE()) end function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) - Y, (nofields, rule, dne) = rrule(reshape, A, dims)[2] - @assert no_fields === NO_FIELDS_RULE - @assert dne === DNERule() - return Y, (NO_FIELDS_RULE, rule, fill(DNERule(), length(dims))...) + return ( + reshape(A, dims...), + Ȳ -> (NO_FIELDS, reshape(Ȳ, dims), fill(DNE(), length(dims))...) + ) end ##### @@ -18,17 +18,20 @@ end ##### function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) - Y = hcat(A, Bs...) - Xs = (A, Bs...) - rules = ntuple(length(Bs) + 1) do i - l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) - u = l + size(Xs[i], 2) - dim = u > l + 1 ? (l+1:u) : u - # NOTE: The copy here is defensive, since `selectdim` returns a view which we can - # materialize with `copy` - Rule(Ȳ->copy(selectdim(Ȳ, 2, dim))) - end - return Y, (NO_FIELDS_RULE, rules...) + function hcat_pullback(Ȳ) + Xs = (A, Bs...) + ntuple(length(Bs) + 1) do full_i + full_i == 1 && return NO_FIELDS + + i = full_i - 1 + l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) + u = l + size(Xs[i], 2) + dim = u > l + 1 ? 
(l+1:u) : u + # NOTE: The copy here is defensive, since `selectdim` returns a view which we can + # materialize with `copy` + copy(selectdim(Ȳ, 2, dim)) + end + return hcat(A, Bs...), hcat_pullback end ##### @@ -36,15 +39,17 @@ end ##### function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) - Y = vcat(A, Bs...) - n = size(A, 1) - ∂A = Rule(Ȳ->copy(selectdim(Ȳ, 1, 1:n))) - ∂Bs = ntuple(length(Bs)) do i - l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0) - u = l + size(Bs[i], 1) - Rule(Ȳ->copy(selectdim(Ȳ, 1, l+1:u))) + function vcat_pullback(Ȳ) + n = size(A, 1) + ∂A = copy(selectdim(Ȳ, 1, 1:n)) + ∂Bs = ntuple(length(Bs)) do i + l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0) + u = l + size(Bs[i], 1) + copy(selectdim(Ȳ, 1, l+1:u)) + end + return (NO_FIELDS, ∂A, ∂Bs...) end - return Y, (NO_FIELDS_RULE, ∂A, ∂Bs...) + return vcat(A, Bs...), vcat_pullback end ##### @@ -52,9 +57,9 @@ end ##### function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) - return fill(value, dims), (NO_FIELDS_RULE, Rule(sum), DNERule()) + return fill(value, dims), Ȳ -> (NO_FIELDS, sum(Ȳ), DNE()) end function rrule(::typeof(fill), value::Any, dims::Int...) - return fill(value, dims), (NO_FIELDS_RULE, Rule(sum), ntuple(_->DNERule(), length(dims))...) + return fill(value, dims), Ȳ -> (NO_FIELDS, sum(Ȳ), ntuple(_->DNE(), length(dims))...) 
end diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index c679087d3..3958f00bd 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -104,7 +104,7 @@ # product rule requires special care for arguments where `mul` is non-commutative frule(::typeof(*), x::Number, y::Number) = x * y, (ZERO_RULE, Rule((Δx, Δy) -> Δx * y + x * Δy)) -rrule(::typeof(*), x::Number, y::Number) = x * y, (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) +rrule(::typeof(*), x::Number, y::Number) = x * y, (NO_FIELDS, Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) frule(::typeof(identity), x) = x, (ZERO_RULE, Rule(identity)) -rrule(::typeof(identity), x) = x, (NO_FIELDS_RULE, Rule(identity)) +rrule(::typeof(identity), x) = x, (NO_FIELDS, Rule(identity)) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index c57ad8c59..63a5086c3 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -20,5 +20,5 @@ end function rrule(::typeof(broadcast), f, x) values, derivs = _cast_diff(f, x) - return values, (NO_FIELDS_RULE, DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs))) + return values, (NO_FIELDS, DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs))) end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 7a01823b0..7f2b5fc4e 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -12,7 +12,7 @@ function rrule(::typeof(map), f, xs...) end end end - return y, (NO_FIELDS_RULE, DNERule(), ∂xs...) + return y, (NO_FIELDS, DNERule(), ∂xs...) 
end ##### @@ -34,7 +34,7 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr) extern(∂xi(ȳi)) end end - return y, (NO_FIELDS_RULE, DNERule(), DNERule(), ∂x) + return y, (NO_FIELDS, DNERule(), DNERule(), ∂x) end eval(Expr(:function, sig, body)) end @@ -45,21 +45,21 @@ end frule(::typeof(sum), x) = (sum(x), (ZERO_RULE, Rule(sum))) -rrule(::typeof(sum), x) = (sum(x), (NO_FIELDS_RULE, Rule(cast))) +rrule(::typeof(sum), x) = (sum(x), (NO_FIELDS, Rule(cast))) function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims) - return y, (NO_FIELDS_RULE, DNERule(), ∂x) + return y, (NO_FIELDS, DNERule(), ∂x) end function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:) y, (no_fields, _, ∂x) = rrule(sum, identity, x; dims=dims) - @assert(no_fields === NO_FIELDS_RULE) - return y, (NO_FIELDS_RULE, ∂x) + @assert(no_fields === NO_FIELDS) + return y, (NO_FIELDS, ∂x) end function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) y = sum(abs2, x; dims=dims) ∂x = Rule(ȳ -> 2ȳ .* x) - return y, (NO_FIELDS_RULE, DNERule(), ∂x) + return y, (NO_FIELDS, DNERule(), ∂x) end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 9f873aedd..d2edc7063 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -21,7 +21,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) Ω = BLAS.dot(n, X, incx, Y, incy) ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) - return Ω, (NO_FIELDS_RULE, DNERule(), _rule_via(∂X), DNERule(), _rule_via(∂Y), DNERule()) + return Ω, (NO_FIELDS, DNERule(), _rule_via(∂X), DNERule(), _rule_via(∂Y), DNERule()) end ##### @@ -35,13 +35,13 @@ end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) - return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * @thunk(x * inv(Ω)))) + return Ω, (NO_FIELDS, Rule(ΔΩ -> ΔΩ * @thunk(x * inv(Ω)))) 
end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) ∂X = ΔΩ -> scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) - return Ω, (NO_FIELDS_RULE, DNERule(), _rule_via(∂X), DNERule()) + return Ω, (NO_FIELDS, DNERule(), _rule_via(∂X), DNERule()) end ##### @@ -53,13 +53,13 @@ function frule(::typeof(BLAS.asum), x) end function rrule(::typeof(BLAS.asum), x) - return BLAS.asum(x), (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * cast(sign, x))) + return BLAS.asum(x), (NO_FIELDS, Rule(ΔΩ -> ΔΩ * cast(sign, x))) end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, sign.(X), incx, _zeros(X), incx), incx) - return Ω, (NO_FIELDS_RULE, DNERule(), _rule_via(∂X), DNERule()) + return Ω, (NO_FIELDS, DNERule(), _rule_via(∂X), DNERule()) end ##### @@ -76,13 +76,13 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, ∂A = Rule(ȳ -> α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā)) ∂x = Rule(ȳ -> gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄)) end - return y, (NO_FIELDS_RULE, DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x) + return y, (NO_FIELDS, DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x) end function rrule(::typeof(gemv), tA::Char, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat y, (dtA, _, dA, dx) = rrule(gemv, tA, one(T), A, x) - return y, (NO_FIELDS_RULE, dtA, dA, dx) + return y, (NO_FIELDS, dtA, dA, dx) end ##### @@ -118,11 +118,11 @@ function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, (B̄, C̄) -> gemm!('T', 'T', α, C̄, A, β, B̄)) end end - return C, (NO_FIELDS_RULE, DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B) + return C, (NO_FIELDS, DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B) end function rrule(::typeof(gemm), tA::Char, tB::Char, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat C, (dtA, dtB, _, dA, dB) = rrule(gemm, tA, tB, one(T), A, B) - return C, (NO_FIELDS_RULE, dtA, dtB, dA, dB) + return C, (NO_FIELDS, dtA, 
dtB, dA, dB) end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 010a7dce1..1e8886698 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -13,7 +13,7 @@ function frule(::typeof(dot), x, y) end function rrule(::typeof(dot), x, y) - return dot(x, y), (NO_FIELDS_RULE, (Rule(ΔΩ -> ΔΩ * cast(y)), Rule(ΔΩ -> cast(x) * ΔΩ))) + return dot(x, y), (NO_FIELDS, (Rule(ΔΩ -> ΔΩ * cast(y)), Rule(ΔΩ -> cast(x) * ΔΩ))) end ##### @@ -29,7 +29,7 @@ end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω') - return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> m * ΔΩ * Ω')) + return Ω, (NO_FIELDS, Rule(ΔΩ -> m * ΔΩ * Ω')) end ##### @@ -43,7 +43,7 @@ end function rrule(::typeof(det), x) Ω, m = det(x), @thunk(inv(x)') - return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> Ω * ΔΩ * m)) + return Ω, (NO_FIELDS, Rule(ΔΩ -> Ω * ΔΩ * m)) end ##### @@ -57,7 +57,7 @@ end function rrule(::typeof(logdet), x) Ω, m = logdet(x), @thunk(inv(x)') - return Ω, (NO_FIELDS_RULE, Rule(ΔΩ -> ΔΩ * m)) + return Ω, (NO_FIELDS, Rule(ΔΩ -> ΔΩ * m)) end ##### @@ -65,14 +65,14 @@ end ##### frule(::typeof(tr), x) = (tr(x), (ZERO_RULE, Rule(Δx -> tr(extern(Δx))))) -rrule(::typeof(tr), x) = (tr(x), (NO_FIELDS_RULE, Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1)))))) +rrule(::typeof(tr), x) = (tr(x), (NO_FIELDS, Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1)))))) ##### ##### `*` ##### function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) - return A * B, (NO_FIELDS_RULE, Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ)) + return A * B, (NO_FIELDS, Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ)) end ##### @@ -84,7 +84,7 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr S = T.name.wrapper ∂A = Rule(Ȳ -> Ȳ / B') ∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B'))) - return Y, (NO_FIELDS_RULE, ∂A, ∂B) + return Y, (NO_FIELDS, ∂A, ∂B) end function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) @@ -94,7 +94,7 @@ 
function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R C, dC = rrule(adjoint, Cᵀ) ∂A = Rule(dA∘dAᵀ∘dC) ∂B = Rule(dA∘dBᵀ∘dC) - return C, (NO_FIELDS_RULE, ∂A, ∂B) + return C, (NO_FIELDS, ∂A, ∂B) end ##### @@ -106,7 +106,7 @@ function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMa S = T.name.wrapper ∂A = Rule(Ȳ -> S(-(A' \ Ȳ) * Y')) ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (NO_FIELDS_RULE, ∂A, ∂B) + return Y, (NO_FIELDS, ∂A, ∂B) end function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) @@ -119,7 +119,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R Ā end ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (NO_FIELDS_RULE, ∂A, ∂B) + return Y, (NO_FIELDS, ∂A, ∂B) end ##### @@ -131,9 +131,9 @@ function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2) u = y^(1-p) ∂A = Rule(ȳ -> ȳ .* u .* abs.(A).^p ./ A) ∂p = Rule(ȳ -> ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p) - return y, (NO_FIELDS_RULE, ∂A, ∂p) + return y, (NO_FIELDS, ∂A, ∂p) end function rrule(::typeof(norm), x::Real, p::Real=2) - return norm(x, p), (NO_FIELDS_RULE, Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x))) + return norm(x, p), (NO_FIELDS, Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x))) end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 6483711b0..cf89e9351 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -10,7 +10,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) ∂X = Rule() do Ȳ::NamedTuple{(:U,:S,:V)} svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) end - return F, (NO_FIELDS_RULE, ∂X) + return F, (NO_FIELDS, ∂X) end function rrule(::typeof(getproperty), F::SVD, x::Symbol) @@ -25,7 +25,7 @@ function rrule(::typeof(getproperty), F::SVD, x::Symbol) throw(ArgumentError("Vt is unsupported; use V and transpose the result")) end update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, 
rule(Ȳ), x) - return getproperty(F, x), (NO_FIELDS_RULE, Rule(rule, update), DNERule()) + return getproperty(F, x), (NO_FIELDS, Rule(rule, update), DNERule()) end function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix) @@ -66,7 +66,7 @@ end function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) F = cholesky(X) ∂X = Rule(Ȳ->chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) - return F, (NO_FIELDS_RULE, ∂X) + return F, (NO_FIELDS, ∂X) end function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) @@ -83,7 +83,7 @@ function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) ∂F = Ȳ->UpperTriangular(Ȳ') end end - return getproperty(F, x), (NO_FIELDS_RULE, Rule(∂F), DNERule()) + return getproperty(F, x), (NO_FIELDS, Rule(∂F), DNERule()) end # See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 9e98b9d99..915cacca9 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -4,15 +4,15 @@ ##### `Diagonal` ##### -rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), (NO_FIELDS_RULE, Rule(diag)) +rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), (NO_FIELDS, Rule(diag)) -rrule(::typeof(diag), A::AbstractMatrix) = diag(A), (NO_FIELDS_RULE, Rule(Diagonal)) +rrule(::typeof(diag), A::AbstractMatrix) = diag(A), (NO_FIELDS, Rule(Diagonal)) ##### ##### `Symmetric` ##### -rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), (NO_FIELDS_RULE, Rule(_symmetric_back)) +rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), (NO_FIELDS, Rule(_symmetric_back)) _symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ @@ -22,26 +22,26 @@ _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### # TODO: Deal with complex-valued arrays as well 
-rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), (NO_FIELDS_RULE, Rule(adjoint)) -rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), (NO_FIELDS_RULE, Rule(vec∘adjoint)) +rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), (NO_FIELDS, Rule(adjoint)) +rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), (NO_FIELDS, Rule(vec∘adjoint)) -rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), (NO_FIELDS_RULE, Rule(adjoint)) -rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), (NO_FIELDS_RULE, Rule(vec∘adjoint)) +rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), (NO_FIELDS, Rule(adjoint)) +rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), (NO_FIELDS, Rule(vec∘adjoint)) ##### ##### `Transpose` ##### -rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), (NO_FIELDS_RULE, Rule(transpose)) -rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), (NO_FIELDS_RULE, Rule(vec∘transpose)) +rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), (NO_FIELDS, Rule(transpose)) +rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), (NO_FIELDS, Rule(vec∘transpose)) -rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), (NO_FIELDS_RULE, Rule(transpose)) -rrule(::typeof(transpose), A::AbstractVector) = transpose(A), (NO_FIELDS_RULE, Rule(vec∘transpose)) +rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), (NO_FIELDS, Rule(transpose)) +rrule(::typeof(transpose), A::AbstractVector) = transpose(A), (NO_FIELDS, Rule(vec∘transpose)) ##### ##### Triangular matrices ##### -rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), (NO_FIELDS_RULE, Rule(Matrix)) +rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), (NO_FIELDS, Rule(Matrix)) -rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), (NO_FIELDS_RULE, Rule(Matrix)) +rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = 
LowerTriangular(A), (NO_FIELDS, Rule(Matrix)) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 70a905638..3ff1d75e8 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -65,7 +65,7 @@ r, (ds, df1, df2) = rrule(atan, x, y) @test r === ratan - @test ds === NO_FIELDS_RULE + @test ds === NO_FIELDS @test df1(1) + df2(2) === datan end @@ -81,7 +81,7 @@ r, (ds, df) = rrule(sincos, x) @test r === rsincos @test df(1, 2) === dsincos - @test ds === NO_FIELDS_RULE + @test ds === NO_FIELDS end end end # Trig @@ -133,7 +133,7 @@ @test z == x * y z̄ = rand(3, 5) - @test ds === NO_FIELDS_RULE + @test ds === NO_FIELDS #== TODO: reanable me @test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄)) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 5f88288f0..a8d925f23 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -47,7 +47,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo F, dX = rrule(cholesky, X) for p in [:U, :L] Y, (dself, dF, dp) = rrule(getproperty, F, p) - @test dself === NO_FIELDS_RULE + @test dself === NO_FIELDS @test dp isa ChainRules.DNERule Ȳ = (p === :U ? 
UpperTriangular : LowerTriangular)(randn(rng, size(Y))) # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` diff --git a/test/test_util.jl b/test/test_util.jl index 24ff36bed..10a12b97d 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -29,7 +29,7 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa @test fx == f(x) # Check we still get the normal value, right # No internal fields - rule===rrule && @test ∂self_rule === NO_FIELDS_RULE + rule===rrule && @test ∂self_rule === NO_FIELDS rule===frule && @test ∂self_rule === ZERO_RULE # Check that we get the derivative right: @@ -111,7 +111,7 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm fx, (∂self_rule, dx_rule) = ChainRules.rrule(f, x) @test fx ≈ f(x) - @test ∂self_rule === NO_FIELDS_RULE # No internal fields + @test ∂self_rule === NO_FIELDS # No internal fields # Correctness testing via finite differencing. x̄_ad = dx_rule(ȳ) @@ -160,7 +160,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm self_rule = rules[1] arg_rules = rules[2:end] - @test self_rule === NO_FIELDS_RULE + @test self_rule === NO_FIELDS # Correctness testing via finite differencing. 
x̄s_ad = map(arg_rules) do rule From 05d0c0638da6d7726d9badb2be369179a2a76fd2 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 2 Sep 2019 15:31:01 +0100 Subject: [PATCH 07/38] all real scalar rules working --- src/ChainRules.jl | 2 +- src/rulesets/Base/array.jl | 1 + test/rulesets/Base/base.jl | 33 ++++++++++++++++++--------------- test/test_util.jl | 23 ++++++++++++++++------- 4 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index d8ab13d49..3a6693b0e 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -24,7 +24,7 @@ end include("helper_functions.jl") -#include("rulesets/Base/base.jl") +include("rulesets/Base/base.jl") include("rulesets/Base/array.jl") #include("rulesets/Base/broadcast.jl") #include("rulesets/Base/mapreduce.jl") diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index e8e0b5309..fba05cd50 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -31,6 +31,7 @@ function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) 
# materialize with `copy` copy(selectdim(Ȳ, 2, dim)) end + end return hcat(A, Bs...), hcat_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 3ff1d75e8..c76032eaa 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -58,29 +58,30 @@ u = x^2 + y^2 datan = y/u - 2x/u - r, (ds, df) = frule(atan, x, y) + r, df = frule(atan, x, y) @test r === ratan - @test df(1, 2) === datan - @test ds === ZERO_RULE + @test df(NamedTuple(), 1, 2) === datan - r, (ds, df1, df2) = rrule(atan, x, y) + r, pullback = rrule(atan, x, y) + (ds, df1, df2) = pullback(1) @test r === ratan @test ds === NO_FIELDS - @test df1(1) + df2(2) === datan + @test df1 + 2df2 === datan end @testset "sincos" begin rsincos = sincos(x) dsincos = cos(x) - 2sin(x) - r, (ds, df1, df2) = frule(sincos, x) + r, pushforward = frule(sincos, x) @test r === rsincos - @test df1(1) + df2(2) === dsincos - @test ds === ZERO_RULE + df1, df2 = pushforward(NamedTuple(), 1) + @test df1 + 2df2 === dsincos - r, (ds, df) = rrule(sincos, x) + r, pullback = rrule(sincos, x) @test r === rsincos - @test df(1, 2) === dsincos + ds, df = pullback(1, 2) + @test df === dsincos @test ds === NO_FIELDS end end @@ -126,29 +127,31 @@ end end + #== TODO Renable me @testset "*(x, y)" begin x, y = rand(3, 2), rand(2, 5) - z, (ds, dx, dy) = rrule(*, x, y) + z, pullback = rrule(*, x, y) @test z == x * y z̄ = rand(3, 5) + (ds, dx, dy) = pullback(z̄) + @test ds === NO_FIELDS - #== TODO: reanable me @test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄)) @test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄)) test_accumulation(rand(3, 2), dx, z̄, z̄ * y') test_accumulation(rand(2, 5), dy, z̄, x' * z̄) - ==# end + ==# @testset "hypot(x, y)" begin x, y = rand(2) - h, (ds, dxy) = frule(hypot, x, y) + h, pushforward = frule(hypot, x, y) + dxy(x, y) = pushforward(NamedTuple(), x, y) # No self gradient - @test ds === ZERO_RULE @test extern(dxy(One(), Zero())) === x / h @test extern(dxy(Zero(), One())) 
=== y / h diff --git a/test/test_util.jl b/test/test_util.jl index 10a12b97d..a84bde5ab 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -25,33 +25,42 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa @testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) res = rule(f, x) @test res !== nothing # Check the rule was defined - fx, (∂self_rule, ∂x_rule) = res + fx, prop_rule = res @test fx == f(x) # Check we still get the normal value, right - # No internal fields - rule===rrule && @test ∂self_rule === NO_FIELDS - rule===frule && @test ∂self_rule === ZERO_RULE + if rule == rrule + ∂self, ∂x = prop_rule(1) + @test ∂self === NO_FIELDS + else # rule == frule + # Got to input extra first aguement for internals + # But it is only a dummy since this is not a functor + ∂x = prop_rule(NamedTuple(), 1) + end + # Check that we get the derivative right: if !test_wirtinger @test isapprox( - ∂x_rule(1), fdm(f, x); + ∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs... ) else + # Wirtinger not currently implemented + #== # For complex arguments, also check if the wirtinger derivative is correct ∂Re = fdm(ϵ -> f(x + ϵ), 0) ∂Im = fdm(ϵ -> f(x + im*ϵ), 0) ∂ = 0.5(∂Re - im*∂Im) ∂̅ = 0.5(∂Re + im*∂Im) @test isapprox( - wirtinger_primal(∂x_rule(1)), ∂; + wirtinger_primal(∂x), ∂; rtol=rtol, atol=atol, kwargs... ) @test isapprox( - wirtinger_conjugate(∂x_rule(1)), ∂̅; + wirtinger_conjugate(∂x), ∂̅; rtol=rtol, atol=atol, kwargs... 
) + ==# end end end From 7818a634e13ece3b7ee0570106447f84c370eeb0 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 2 Sep 2019 17:58:40 +0100 Subject: [PATCH 08/38] Wirtinger scalars passing --- test/rulesets/Base/base.jl | 4 ++-- test/rulesets/Base/broadcast.jl | 4 ++-- test/test_util.jl | 26 +++++++++++--------------- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index c76032eaa..c8e57961c 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -60,7 +60,7 @@ r, df = frule(atan, x, y) @test r === ratan - @test df(NamedTuple(), 1, 2) === datan + @test df(NamedTuple(), 1, 2) === (datan,) r, pullback = rrule(atan, x, y) (ds, df1, df2) = pullback(1) @@ -150,7 +150,7 @@ @testset "hypot(x, y)" begin x, y = rand(2) h, pushforward = frule(hypot, x, y) - dxy(x, y) = pushforward(NamedTuple(), x, y) # No self gradient + dxy(x, y) = pushforward(NamedTuple(), x, y)[1] @test extern(dxy(One(), Zero())) === x / h @test extern(dxy(Zero(), One())) === y / h diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 45e85031c..e6e07db97 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -10,13 +10,13 @@ x̄, ȳ = rand(), rand() - @test_skip isequal( + @test isequal( extern(ChainRules.accumulate(x̄, dx, ȳ)), x̄ .+ ȳ .* cos.(x) ) x̄, ȳ = Zero(), rand(3, 3) - @test_skip extern(accumulate(x̄, dx, ȳ)) == ȳ .* cos.(x) + @test extern(accumulate(x̄, dx, ȳ)) == ȳ .* cos.(x) end end end diff --git a/test/test_util.jl b/test/test_util.jl index a84bde5ab..b2d916106 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -34,7 +34,7 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa else # rule == frule # Got to input extra first aguement for internals # But it is only a dummy since this is not a functor - ∂x = prop_rule(NamedTuple(), 1) + ∂x, = prop_rule(NamedTuple(), 1) end @@ -45,8 +45,6 @@ function 
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa rtol=rtol, atol=atol, kwargs... ) else - # Wirtinger not currently implemented - #== # For complex arguments, also check if the wirtinger derivative is correct ∂Re = fdm(ϵ -> f(x + ϵ), 0) ∂Im = fdm(ϵ -> f(x + im*ϵ), 0) @@ -60,7 +58,6 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa wirtinger_conjugate(∂x), ∂̅; rtol=rtol, atol=atol, kwargs... ) - ==# end end end @@ -207,7 +204,7 @@ function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) end function test_accumulation(x̄, dx, ȳ, partial) - @test_skip all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) + @test all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) test_accumulate(x̄, dx, ȳ, partial) test_accumulate!(x̄, dx, ȳ, partial) test_store!(x̄, dx, ȳ, partial) @@ -215,18 +212,18 @@ function test_accumulation(x̄, dx, ȳ, partial) end function test_accumulate(x̄::Zero, dx, ȳ, partial) - @test_skip extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) + @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) return nothing end function test_accumulate(x̄::Number, dx, ȳ, partial) - @test_skip extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) + @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) return nothing end function test_accumulate(x̄::AbstractArray, dx, ȳ, partial) x̄_old = copy(x̄) - @test_skip all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) + @test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) @test x̄ == x̄_old return nothing end @@ -234,15 +231,15 @@ end test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing function test_accumulate!(x̄::Number, dx, ȳ, partial) - #@test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ) + @test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ) return nothing end function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial) x̄_copy = copy(x̄) - #TODO Reeable me - #accumulate!(x̄_copy, dx, ȳ) - 
#@test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) + + accumulate!(x̄_copy, dx, ȳ) + @test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) return nothing end @@ -251,8 +248,7 @@ test_store!(x̄::Number, dx, ȳ, partial) = nothing function test_store!(x̄::AbstractArray, dx, ȳ, partial) x̄_copy = copy(x̄) - # TODO: renable me - #store!(x̄_copy, dx, ȳ) - #@test all(x̄_copy .≈ extern(partial)) + store!(x̄_copy, dx, ȳ) + @test all(x̄_copy .≈ extern(partial)) return nothing end From 1a43009d6f24f09c150bb21fb94f44e29ca2382f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 3 Sep 2019 19:59:41 +0100 Subject: [PATCH 09/38] all tests in tests/rulesets/Base/base.jl passing --- src/ChainRules.jl | 12 +-- src/rulesets/Base/base.jl | 8 +- src/rulesets/LinearAlgebra/dense.jl | 115 ++++++++++++++++++---------- test/rulesets/Base/base.jl | 9 +-- test/test_util.jl | 71 ++++++++--------- 5 files changed, 123 insertions(+), 92 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 3a6693b0e..790c91497 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -29,14 +29,14 @@ include("rulesets/Base/array.jl") #include("rulesets/Base/broadcast.jl") #include("rulesets/Base/mapreduce.jl") -#== -include("rulesets/Statistics/statistics.jl") -include("rulesets/LinearAlgebra/utils.jl") -include("rulesets/LinearAlgebra/blas.jl") +#include("rulesets/Statistics/statistics.jl") + +#include("rulesets/LinearAlgebra/utils.jl") +#include("rulesets/LinearAlgebra/blas.jl") include("rulesets/LinearAlgebra/dense.jl") -include("rulesets/LinearAlgebra/structured.jl") -include("rulesets/LinearAlgebra/factorization.jl") +#include("rulesets/LinearAlgebra/structured.jl") +#include("rulesets/LinearAlgebra/factorization.jl") # Note: The following is only required because package authors sometimes do not # declare their own rules using `ChainRulesCore.jl`. For arguably good reasons. 
diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 3958f00bd..fb5c473dd 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -103,8 +103,8 @@ # product rule requires special care for arguments where `mul` is non-commutative -frule(::typeof(*), x::Number, y::Number) = x * y, (ZERO_RULE, Rule((Δx, Δy) -> Δx * y + x * Δy)) -rrule(::typeof(*), x::Number, y::Number) = x * y, (NO_FIELDS, Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) +frule(::typeof(*), x::Number, y::Number) = x * y, (_, Δx, Δy) -> Δx * y + x * Δy +rrule(::typeof(*), x::Number, y::Number) = x * y, (ΔΩ -> (NO_FIELDS, ΔΩ * y', x' * ΔΩ)) -frule(::typeof(identity), x) = x, (ZERO_RULE, Rule(identity)) -rrule(::typeof(identity), x) = x, (NO_FIELDS, Rule(identity)) +frule(::typeof(identity), x) = x, (_, ȳ) -> ȳ +rrule(::typeof(identity), x) = x, ȳ -> (NO_FIELDS, ȳ) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 1e8886698..42e8c22cf 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -9,11 +9,17 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} ##### function frule(::typeof(dot), x, y) - return dot(x, y), (ZERO_RULE, Rule((Δx, Δy) -> sum(Δx * cast(y)) + sum(cast(x) * Δy))) + function dot_pushforward(Δself, Δx, Δy) + sum(Δx * cast(y)) + sum(cast(x) * Δy) + end + return dot(x, y), dot_pushforward end function rrule(::typeof(dot), x, y) - return dot(x, y), (NO_FIELDS, (Rule(ΔΩ -> ΔΩ * cast(y)), Rule(ΔΩ -> cast(x) * ΔΩ))) + function dot_pullback(ΔΩ) + (NO_FIELDS, ΔΩ * cast(y), cast(x) * ΔΩ,) + end + return dot(x, y), dot_pullback end ##### @@ -23,13 +29,19 @@ end function frule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω) - return Ω, (ZERO_RULE, Rule(Δx -> m * Δx * Ω)) + function inv_pushforward(_, Δx) + m * Δx * Ω + end + return Ω, inv_pushforward end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω') - return Ω, (NO_FIELDS, Rule(ΔΩ -> 
m * ΔΩ * Ω')) + function inv_pullback(ΔΩ) + NO_FIELDS, m * ΔΩ * Ω' + end + return Ω, inv_pullback end ##### @@ -38,12 +50,12 @@ end function frule(::typeof(det), x) Ω, m = det(x), @thunk(inv(x)) - return Ω, (ZERO_RULE, Rule(Δx -> Ω * tr(extern(m * Δx)))) + return Ω, (_, Δx) -> Ω * tr(extern(m * Δx)) end function rrule(::typeof(det), x) Ω, m = det(x), @thunk(inv(x)') - return Ω, (NO_FIELDS, Rule(ΔΩ -> Ω * ΔΩ * m)) + return Ω, ΔΩ -> (NO_FIELDS, Ω * ΔΩ * m) end ##### @@ -52,27 +64,27 @@ end function frule(::typeof(logdet), x) Ω, m = logdet(x), @thunk(inv(x)) - return Ω, (ZERO_RULE, Rule(Δx -> tr(extern(m * Δx)))) + return Ω, (_, Δx) -> tr(extern(m * Δx)) end function rrule(::typeof(logdet), x) Ω, m = logdet(x), @thunk(inv(x)') - return Ω, (NO_FIELDS, Rule(ΔΩ -> ΔΩ * m)) + return Ω, ΔΩ -> (NO_FIELDS, ΔΩ * m) end ##### ##### `trace` ##### -frule(::typeof(tr), x) = (tr(x), (ZERO_RULE, Rule(Δx -> tr(extern(Δx))))) -rrule(::typeof(tr), x) = (tr(x), (NO_FIELDS, Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1)))))) +frule(::typeof(tr), x) = (tr(x), (_, Δx) -> tr(extern(Δx))) +rrule(::typeof(tr), x) = (tr(x), ΔΩ -> (NO_FIELDS, Diagonal(fill(ΔΩ, size(x, 1))))) ##### ##### `*` ##### function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) - return A * B, (NO_FIELDS, Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ)) + return A * B, Ȳ -> (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) end ##### @@ -81,20 +93,30 @@ end function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real} Y = A / B - S = T.name.wrapper - ∂A = Rule(Ȳ -> Ȳ / B') - ∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B'))) - return Y, (NO_FIELDS, ∂A, ∂B) + function slash_pullback(Ȳ) + S = T.name.wrapper + ∂A = @thunk Ȳ / B' + ∂B = @thunk S(-Y' * (Ȳ / B')) + (NO_FIELDS, ∂A, ∂B) + end + return Y, slash_pullback end function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Aᵀ, dA = rrule(adjoint, A) - Bᵀ, dB = rrule(adjoint, B) - Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ) - C, dC = 
rrule(adjoint, Cᵀ) - ∂A = Rule(dA∘dAᵀ∘dC) - ∂B = Rule(dA∘dBᵀ∘dC) - return C, (NO_FIELDS, ∂A, ∂B) + Aᵀ, dA_pb = rrule(adjoint, A) + Bᵀ, dB_pb = rrule(adjoint, B) + Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) + C, dC_pb = rrule(adjoint, Cᵀ) + function slash_pullback(Ȳ) + _, dC = dC_pb(Ȳ) + _, dAᵀ, dBᵀ = dS_pb(dC) + + ∂A = @thunk last(dA_pb(dAᵀ)) + ∂B = @thunk last(dA_pb(dBᵀ)) + + (NO_FIELDS, ∂A, ∂B) + end + return C, slash_pullback end ##### @@ -103,23 +125,30 @@ end function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} Y = A \ B - S = T.name.wrapper - ∂A = Rule(Ȳ -> S(-(A' \ Ȳ) * Y')) - ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (NO_FIELDS, ∂A, ∂B) + function forwardslash_pullback(Ȳ) + S = T.name.wrapper + ∂A = @thunk S(-(A' \ Ȳ) * Y') + ∂B = @thunk A' \ Ȳ + return NO_FIELDS, ∂A, ∂B + end + return Y, forwardslash_pullback end function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) Y = A \ B - ∂A = Rule() do Ȳ - B̄ = A' \ Ȳ - Ā = -B̄ * Y' - _add!(Ā, (B - A * Y) * B̄' / A') - _add!(Ā, A' \ Y * (Ȳ' - B̄'A)) - Ā + function forwardslash_pullback(Ȳ) + ∂A = @thunk begin + B̄ = A' \ Ȳ + Ā = -B̄ * Y' + _add!(Ā, (B - A * Y) * B̄' / A') + _add!(Ā, A' \ Y * (Ȳ' - B̄'A)) + Ā + end + ∂B = @thunk A' \ Ȳ + return NO_FIELDS, ∂A, ∂B end - ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (NO_FIELDS, ∂A, ∂B) + return Y, forwardslash_pullback + end ##### @@ -128,12 +157,20 @@ end function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2) y = norm(A, p) - u = y^(1-p) - ∂A = Rule(ȳ -> ȳ .* u .* abs.(A).^p ./ A) - ∂p = Rule(ȳ -> ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p) - return y, (NO_FIELDS, ∂A, ∂p) + function norm_pullback(ȳ) + u = y^(1-p) + ∂A = @thunk ȳ .* u .* abs.(A).^p ./ A + ∂p = @thunk ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p + (NO_FIELDS, ∂A, ∂p) + end + return y, norm_pullback end function rrule(::typeof(norm), x::Real, p::Real=2) - return norm(x, p), (NO_FIELDS, Rule(ȳ -> ȳ * sign(x)), Rule(_ -> 
zero(x))) + function norm_pullback(ȳ) + ∂x = @thunk ȳ * sign(x) + ∂p = @thunk zero(x) #TODO: should this be Zero() + (NO_FIELDS, ∂x, ∂p) + end + return norm(x, p), norm_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index c8e57961c..ccb4851ca 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -139,13 +139,12 @@ @test ds === NO_FIELDS - @test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄)) - @test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄)) + @test extern(dx) == extern(accumulate(zeros(3, 2), dx)) + @test extern(dy) == extern(accumulate(zeros(2, 5), dy)) - test_accumulation(rand(3, 2), dx, z̄, z̄ * y') - test_accumulation(rand(2, 5), dy, z̄, x' * z̄) + test_accumulation(rand(3, 2), dx, z̄ * y') + test_accumulation(rand(2, 5), dy, x' * z̄) end - ==# @testset "hypot(x, y)" begin x, y = rand(2) diff --git a/test/test_util.jl b/test/test_util.jl index b2d916106..e8f340180 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -87,13 +87,11 @@ end function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) ensure_not_running_on_functor(f, "frule_test") xs, ẋs = collect(zip(xẋs...)) - Ω, (∂self_rule, dΩ_rule) = ChainRules.frule(f, xs...) + Ω, pushforward = ChainRules.frule(f, xs...) @test f(xs...) == Ω - - @test ∂self_rule === ZERO_RULE # No internal fields + dΩ_ad = pushforward(NamedTuple(), ẋs...) # Correctness testing via finite differencing. - dΩ_ad = dΩ_rule(ẋs...) dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) @test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...) end @@ -114,19 +112,18 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm ensure_not_running_on_functor(f, "rrule_test") # Check correctness of evaluation. 
- fx, (∂self_rule, dx_rule) = ChainRules.rrule(f, x) + fx, pullback = ChainRules.rrule(f, x) @test fx ≈ f(x) - - @test ∂self_rule === NO_FIELDS # No internal fields - + (∂self, x̄_ad) = pullback(ȳ) + @test ∂self === NO_FIELDS # No internal fields # Correctness testing via finite differencing. - x̄_ad = dx_rule(ȳ) x̄_fd = j′vp(fdm, f, ȳ, x) @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) # Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct. - test_accumulation(x̄, dx_rule, ȳ, x̄_ad) - test_accumulation(Zero(), dx_rule, ȳ, x̄_ad) + # TODO is this test nonsense now? + test_accumulation(x̄, x̄_ad, x̄_ad) + test_accumulation(Zero(), x̄_ad, x̄_ad) end function _make_fdm_call(fdm, f, ȳ, xs, ignores) @@ -161,17 +158,15 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm # Check correctness of evaluation. xs, x̄s = collect(zip(xx̄s...)) - y, rules = rrule(f, xs...) + y, pullback = rrule(f, xs...) @test f(xs...) == y - self_rule = rules[1] - arg_rules = rules[2:end] + ∂s = pullback(ȳ) + ∂self = ∂s[1] + x̄s_ad = ∂s[2:end] @test self_rule === NO_FIELDS # Correctness testing via finite differencing. - x̄s_ad = map(arg_rules) do rule - rule isa DNERule ? DNE() : rule(ȳ) - end x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing) for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd) if x̄_fd === nothing @@ -185,8 +180,8 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm # Assuming the above to be correct, check that other ChainRules mechanisms are correct. for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad) x̄ === nothing && continue - #test_accumulation(x̄, rule, ȳ, x̄_ad) - #test_accumulation(Zero(), rule, ȳ, x̄_ad) + test_accumulation(x̄, x̄_ad) + test_accumulation(Zero(), x̄_ad) end end @@ -203,52 +198,52 @@ function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) return isapprox(extern(d_ad), d_fd; kwargs...) 
end -function test_accumulation(x̄, dx, ȳ, partial) +function test_accumulation(x̄, dx, partial) @test all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) - test_accumulate(x̄, dx, ȳ, partial) - test_accumulate!(x̄, dx, ȳ, partial) - test_store!(x̄, dx, ȳ, partial) + test_accumulate(x̄, dx, partial) + test_accumulate!(x̄, dx, partial) + test_store!(x̄, dx, partial) return nothing end -function test_accumulate(x̄::Zero, dx, ȳ, partial) - @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) +function test_accumulate(x̄::Zero, dx, partial) + @test extern(accumulate(x̄, dx)) ≈ extern(partial) return nothing end -function test_accumulate(x̄::Number, dx, ȳ, partial) - @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) +function test_accumulate(x̄::Number, dx, partial) + @test extern(accumulate(x̄, dx)) ≈ extern(x̄) + extern(partial) return nothing end -function test_accumulate(x̄::AbstractArray, dx, ȳ, partial) +function test_accumulate(x̄::AbstractArray, dx, partial) x̄_old = copy(x̄) - @test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) + @test all(extern(accumulate(x̄, dx)) .≈ (extern(x̄) .+ extern(partial))) @test x̄ == x̄_old return nothing end -test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing +test_accumulate!(x̄::Zero, dx, partial) = nothing -function test_accumulate!(x̄::Number, dx, ȳ, partial) - @test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ) +function test_accumulate!(x̄::Number, dx, partial) + @test accumulate!(x̄, dx) ≈ accumulate(x̄, dx) return nothing end -function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial) +function test_accumulate!(x̄::AbstractArray, dx, partial) x̄_copy = copy(x̄) - accumulate!(x̄_copy, dx, ȳ) + accumulate!(x̄_copy, dx) @test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) return nothing end -test_store!(x̄::Zero, dx, ȳ, partial) = nothing -test_store!(x̄::Number, dx, ȳ, partial) = nothing +test_store!(x̄::Zero, dx, partial) = nothing +test_store!(x̄::Number, dx, 
partial) = nothing -function test_store!(x̄::AbstractArray, dx, ȳ, partial) +function test_store!(x̄::AbstractArray, dx, partial) x̄_copy = copy(x̄) - store!(x̄_copy, dx, ȳ) + store!(x̄_copy, dx) @test all(x̄_copy .≈ extern(partial)) return nothing end From 6592c95408669996a1a86cbf3441242a505a1e7d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 5 Sep 2019 14:13:25 +0100 Subject: [PATCH 10/38] Fixup Base tests to match frule not returning a tuple --- proto.jl | 52 -------------------------------------- test/rulesets/Base/base.jl | 11 ++++---- 2 files changed, 6 insertions(+), 57 deletions(-) delete mode 100644 proto.jl diff --git a/proto.jl b/proto.jl deleted file mode 100644 index 52de5fc2d..000000000 --- a/proto.jl +++ /dev/null @@ -1,52 +0,0 @@ -using Pkg: @pkg_str -pkg"activate /Users/oxinabox/JuliaEnvs/ChainRulesWorld/" -using Revise - - -include("/Users/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRulesCore.jl/test/runtests.jl") - -include("/Users/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRules.jl/test/runtests.jl") - - -using FiniteDifferences -using Test -using ChainRules -using Random - -const accumulate = ChainRules.ChainRulesCore.accumulate -const accumulate! = ChainRules.ChainRulesCore.accumulate! 
-const add = ChainRules.ChainRulesCore.add -includet("/Users/oxinabox/JuliaEnvs/ChainRulesWorld/ChainRules.jl/test/test_util.jl") - -#== -Test Summary: | Pass Fail Error Total -ChainRules | 2271 38 88 2397 - -==# -using MacroTools -using MacroTools: textwalk - -code = """ -Rule(x -> -sin(x)) -Rule(x -> 1 + tan(x)^2) -"""; - -after = textwalk(code) do expr - @capture(expr, Rule(v_)) && return MacroTools.postwalk(MacroTools.unblock, v) - return expr -end - -println(after) - -############## - - -code = """ -y->-sin(x) -"""; - -after = textwalk(code) do expr - @capture(expr, v_) && return v -end - -println(after) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index ccb4851ca..9f90fa180 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -54,18 +54,19 @@ @testset "Multivariate" begin x, y = rand(2) @testset "atan2" begin + # https://en.wikipedia.org/wiki/Atan2 ratan = atan(x, y) # https://en.wikipedia.org/wiki/Atan2 u = x^2 + y^2 datan = y/u - 2x/u - r, df = frule(atan, x, y) + r, pushforward = frule(atan, x, y) @test r === ratan - @test df(NamedTuple(), 1, 2) === (datan,) + @test pushforward(NamedTuple(), 1, 2) === datan r, pullback = rrule(atan, x, y) - (ds, df1, df2) = pullback(1) @test r === ratan - @test ds === NO_FIELDS + dself, df1, df2 = pullback(1) + @test dself == NO_FIELDS @test df1 + 2df2 === datan end @@ -149,7 +150,7 @@ @testset "hypot(x, y)" begin x, y = rand(2) h, pushforward = frule(hypot, x, y) - dxy(x, y) = pushforward(NamedTuple(), x, y)[1] + dxy(x, y) = pushforward(NamedTuple(), x, y) @test extern(dxy(One(), Zero())) === x / h @test extern(dxy(Zero(), One())) === y / h From 74434c5f48667ef579f1b22feecd03873b6b5fa2 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 5 Sep 2019 14:48:47 +0100 Subject: [PATCH 11/38] attay test passing --- src/rulesets/Base/array.jl | 2 +- test/rulesets/Base/array.jl | 54 ++++++++++++++++++++++--------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git 
a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index fba05cd50..3be8994c8 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -20,7 +20,7 @@ end function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) function hcat_pullback(Ȳ) Xs = (A, Bs...) - ntuple(length(Bs) + 1) do full_i + ntuple(length(Bs) + 2) do full_i full_i == 1 && return NO_FIELDS i = full_i - 1 diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 3a70b9ecd..4c7a00d5e 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,19 +1,23 @@ @testset "reshape" begin rng = MersenneTwister(1) A = randn(rng, 4, 5) - B, (dA, dd) = rrule(reshape, A, (5, 4)) + B, pullback = rrule(reshape, A, (5, 4)) @test B == reshape(A, (5, 4)) - @test dd isa ChainRules.DNERule Ȳ = randn(rng, 4, 5) - Ā = dA(Ȳ) + + (s̄, Ā, d̄) = pullback(Ȳ) + @test s̄ == NO_FIELDS + @test d̄ isa DNE @test Ā == reshape(Ȳ, (5, 4)) - B, (dA, dd1, dd2) = rrule(reshape, A, 5, 4) + B, pullback = rrule(reshape, A, 5, 4) @test B == reshape(A, 5, 4) - @test dd1 isa ChainRules.DNERule - @test dd2 isa ChainRules.DNERule + Ȳ = randn(rng, 4, 5) - Ā = dA(Ȳ) + (s̄, Ā, d̄1, d̄2) = pullback(Ȳ) + @test s̄ == NO_FIELDS + @test d̄1 isa DNE + @test d̄2 isa DNE @test Ā == reshape(Ȳ, 5, 4) end @@ -22,12 +26,14 @@ end A = randn(rng, 3, 2) B = randn(rng, 3) C = randn(rng, 3, 3) - H, (dA, dB, dC) = rrule(hcat, A, B, C) + H, pullback = rrule(hcat, A, B, C) @test H == hcat(A, B, C) H̄ = randn(rng, 3, 6) - @test dA(H̄) ≈ view(H̄, :, 1:2) - @test dB(H̄) ≈ view(H̄, :, 3) - @test dC(H̄) ≈ view(H̄, :, 4:6) + (ds, dA, dB, dC) = pullback(H̄) + @test ds == NO_FIELDS + @test dA ≈ view(H̄, :, 1:2) + @test dB ≈ view(H̄, :, 3) + @test dC ≈ view(H̄, :, 4:6) end @testset "vcat" begin @@ -35,22 +41,28 @@ end A = randn(rng, 2, 4) B = randn(rng, 1, 4) C = randn(rng, 3, 4) - V, (dA, dB, dC) = rrule(vcat, A, B, C) + V, pullback = rrule(vcat, A, B, C) @test V == vcat(A, B, C) V̄ = randn(rng, 6, 4) - 
@test dA(V̄) ≈ view(V̄, 1:2, :) - @test dB(V̄) ≈ view(V̄, 3:3, :) - @test dC(V̄) ≈ view(V̄, 4:6, :) + (ds, dA, dB, dC) = pullback(V̄) + @test ds == NO_FIELDS + @test dA ≈ view(V̄, 1:2, :) + @test dB ≈ view(V̄, 3:3, :) + @test dC ≈ view(V̄, 4:6, :) end @testset "fill" begin - y, (dv, dd) = rrule(fill, 44, 4) + y, pullback = rrule(fill, 44, 4) @test y == [44, 44, 44, 44] - @test dd isa ChainRules.DNERule - @test dv(ones(Int, 4)) == 4 + (ds, dv, dd) = pullback(ones(4)) + @test ds === NO_FIELDS + @test dd isa DNE + @test dv == 4 - y, (dv, dd) = rrule(fill, 2.0, (3, 3, 3)) + y, pullback = rrule(fill, 2.0, (3, 3, 3)) @test y == fill(2.0, (3, 3, 3)) - @test dd isa ChainRules.DNERule - @test dv(ones(3, 3, 3)) ≈ 27.0 + (ds, dv, dd) = pullback(ones(3, 3, 3)) + @test ds === NO_FIELDS + @test dd isa DNE + @test dv ≈ 27.0 end From 1937ec123a82508d93d6c51a5d023edf28ad3e51 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 5 Sep 2019 16:10:55 +0100 Subject: [PATCH 12/38] Broadcast fixed --- src/ChainRules.jl | 2 +- src/rulesets/Base/broadcast.jl | 8 ++++---- test/rulesets/Base/broadcast.jl | 28 +++++++++++++++++++--------- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 790c91497..d11d66c77 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -26,7 +26,7 @@ include("helper_functions.jl") include("rulesets/Base/base.jl") include("rulesets/Base/array.jl") -#include("rulesets/Base/broadcast.jl") +include("rulesets/Base/broadcast.jl") #include("rulesets/Base/mapreduce.jl") diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 63a5086c3..8edb7fa5e 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -5,9 +5,9 @@ without relying on inference hacks unless we have something akin to https://github.com/JuliaLang/julia/issues/22129. 
=# function _cast_diff(f, x) - element_rule = u -> begin + function element_rule(u) fu, du = frule(f, u) - fu, extern(du(One())) + fu, extern(du(NamedTuple(), One())) end results = broadcast(element_rule, x) return first.(results), last.(results) @@ -15,10 +15,10 @@ end function frule(::typeof(broadcast), f, x) Ω, ∂x = _cast_diff(f, x) - return Ω, (ZERO_RULE, Rule((_, Δx) -> Δx * cast(∂x))) + return Ω, (_, Δf, Δx) -> Δx * cast(∂x) end function rrule(::typeof(broadcast), f, x) values, derivs = _cast_diff(f, x) - return values, (NO_FIELDS, DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs))) + return values, ΔΩ -> (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs))) end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index e6e07db97..4dd4593ae 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -1,22 +1,32 @@ @testset "broadcast" begin - @testset "Misc. Tests" begin - @testset "sin.(x)" begin + @testset "sin.(x)" begin + @testset "rrule" begin x = rand(3, 3) - y, (dsin, dx) = rrule(broadcast, sin, x) - + y, pullback = rrule(broadcast, sin, x) @test y == sin.(x) - @test extern(dx(One())) == cos.(x) + (dself, dsin, dx) = pullback(One()) + @test dself == NO_FIELDS + @test dsin == DNE() + @test extern(extern(dx)) == cos.(x) x̄, ȳ = rand(), rand() - - + ∂x = pullback(ȳ)[3] @test isequal( - extern(ChainRules.accumulate(x̄, dx, ȳ)), + extern(ChainRules.accumulate(x̄, ∂x)), x̄ .+ ȳ .* cos.(x) ) x̄, ȳ = Zero(), rand(3, 3) - @test extern(accumulate(x̄, dx, ȳ)) == ȳ .* cos.(x) + ∂x = pullback(ȳ)[3] + @test extern(extern(accumulate(x̄, ∂x))) == ȳ .* cos.(x) + end + @testset "frule" begin + x = rand(3, 3) + y, pushforward = frule(broadcast, sin, x) + @test y == sin.(x) + + ẏ = pushforward(NamedTuple(), NamedTuple(), One()) + @test extern(ẏ) == cos.(x) end end end From facf9949345188c10060e9a4970818a95aa81497 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 5 Sep 2019 20:03:49 +0100 Subject: [PATCH 13/38] WIP fixing up mapreduce 
file --- src/ChainRules.jl | 4 +-- src/helper_functions.jl | 5 ++++ src/rulesets/Base/mapreduce.jl | 49 +++++++++++++++++++++------------- test/test_util.jl | 8 +++--- 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index d11d66c77..61f502e34 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -27,7 +27,7 @@ include("helper_functions.jl") include("rulesets/Base/base.jl") include("rulesets/Base/array.jl") include("rulesets/Base/broadcast.jl") -#include("rulesets/Base/mapreduce.jl") +include("rulesets/Base/mapreduce.jl") #include("rulesets/Statistics/statistics.jl") @@ -52,5 +52,5 @@ function __init__() using .SpecialFunctionsGlue end end -==# + end # module diff --git a/src/helper_functions.jl b/src/helper_functions.jl index 392ae2de6..ddd371e60 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -24,7 +24,12 @@ function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns return _update!(x, getproperty(y, p), p) end +""" + _checked_rrule +like `rrule` but throws an error if the `rrule` is not defined. +Rather than returning `nothing` +""" function _checked_rrule(f, args...; kwargs...) r = rrule(f, args...; kwargs...) r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 7f2b5fc4e..57efe8d22 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -4,15 +4,19 @@ function rrule(::typeof(map), f, xs...) y = map(f, xs...) - ∂xs = ntuple(length(xs)) do i - Rule() do ȳ - map(ȳ, xs...) do ȳi, xis... - _, ∂xis = _checked_rrule(f, xis...) - extern(∂xis[i](ȳi)) + function map_pullback(ȳ) + ntuple(length(xs)+2) do full_i + full_i == 1 && return NO_FIELDS + full_i == 2 && return DNE() + i = full_i-2 + @thunk map(ȳ, xs...) do ȳi, xis... + _, pullback = _checked_rrule(f, xis...) 
+ ∂xis = pullback(ȳi) + extern(∂xis[i+1]) #+1 to skp ∂self end end end - return y, (NO_FIELDS, DNERule(), ∂xs...) + return y, map_pullback end ##### @@ -26,15 +30,17 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr) insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:)))) insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims))) end + pullback_name = Symbol(mf, :_pullback) body = quote y = $call - ∂x = Rule() do ȳ - broadcast(x, ȳ) do xi, ȳi + function $pullback_name(ȳ) + ∂x = @thunk broadcast(x, ȳ) do xi, ȳi _, ∂xi = _checked_rrule(f, xi) extern(∂xi(ȳi)) end + (NO_FIELDS, DNERule(), DNERule(), ∂x) end - return y, (NO_FIELDS, DNERule(), DNERule(), ∂x) + return y, $pullback_name end eval(Expr(:function, sig, body)) end @@ -43,23 +49,30 @@ end ##### `sum` ##### -frule(::typeof(sum), x) = (sum(x), (ZERO_RULE, Rule(sum))) +frule(::typeof(sum), x) = (sum(x), (_, ẋ)->sum(ẋ)) -rrule(::typeof(sum), x) = (sum(x), (NO_FIELDS, Rule(cast))) +rrule(::typeof(sum), x) = (sum(x), ȳ->(NO_FIELDS, cast(ȳ))) function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) - y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims) - return y, (NO_FIELDS, DNERule(), ∂x) + y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims) + function sum_pullback(ȳ) + NO_FIELDS, DNERule(), @thunk(last(mr_pullback(ȳ))) + end + return y, sum_pullback end function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:) - y, (no_fields, _, ∂x) = rrule(sum, identity, x; dims=dims) - @assert(no_fields === NO_FIELDS) - return y, (NO_FIELDS, ∂x) + y, inner_pullback = rrule(sum, identity, x; dims=dims) + function sum_pullback(ȳ) + NO_FIELDS, DNERule(), @thunk(last(inner_pullback(ȳ))) + end + return y, sum_pullback end function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) y = sum(abs2, x; dims=dims) - ∂x = Rule(ȳ -> 2ȳ .* x) - return y, (NO_FIELDS, DNERule(), ∂x) + function sum_abs2_pullback(ȳ) + (NO_FIELDS, DNERule(), @thunk(2ȳ .* x)) + end + return y, 
sum_abs2_pullback end diff --git a/test/test_util.jl b/test/test_util.jl index e8f340180..faab75871 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -164,7 +164,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm ∂s = pullback(ȳ) ∂self = ∂s[1] x̄s_ad = ∂s[2:end] - @test self_rule === NO_FIELDS + @test ∂self === NO_FIELDS # Correctness testing via finite differencing. x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing) @@ -178,10 +178,10 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm end # Assuming the above to be correct, check that other ChainRules mechanisms are correct. - for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad) + for (x̄, x̄_ad) in zip(x̄s, x̄s_ad) x̄ === nothing && continue - test_accumulation(x̄, x̄_ad) - test_accumulation(Zero(), x̄_ad) + test_accumulation(x̄, x̄_ad, x̄_ad) + test_accumulation(Zero(), x̄_ad, x̄_ad) end end From 9849149ea2534ee954de58eefd830c674ff3f6c5 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 18:16:24 +0100 Subject: [PATCH 14/38] make structured and dense rulesets pass --- src/ChainRules.jl | 6 ++--- src/rulesets/Base/mapreduce.jl | 15 ++++++------ src/rulesets/LinearAlgebra/dense.jl | 19 +++++++++++++--- src/rulesets/LinearAlgebra/structured.jl | 29 ++++++++++++------------ test/rulesets/Base/mapreduce.jl | 9 ++++---- test/rulesets/LinearAlgebra/dense.jl | 23 +++++++++++-------- test/test_util.jl | 9 +++++++- 7 files changed, 68 insertions(+), 42 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 61f502e34..c7afefb90 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -32,11 +32,11 @@ include("rulesets/Base/mapreduce.jl") #include("rulesets/Statistics/statistics.jl") -#include("rulesets/LinearAlgebra/utils.jl") +include("rulesets/LinearAlgebra/utils.jl") #include("rulesets/LinearAlgebra/blas.jl") include("rulesets/LinearAlgebra/dense.jl") -#include("rulesets/LinearAlgebra/structured.jl") 
-#include("rulesets/LinearAlgebra/factorization.jl") +include("rulesets/LinearAlgebra/structured.jl") +include("rulesets/LinearAlgebra/factorization.jl") # Note: The following is only required because package authors sometimes do not # declare their own rules using `ChainRulesCore.jl`. For arguably good reasons. diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 57efe8d22..4d7dfa73e 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -24,7 +24,7 @@ end ##### for mf in (:mapreduce, :mapfoldl, :mapfoldr) - sig = :(rrule(::typeof($mf), f, op, x::AbstractArray{<:Real})) + sig = :(ChainRulesCore.rrule(::typeof($mf), f, op, x::AbstractArray{<:Real})) call = :($mf(f, op, x)) if mf === :mapreduce insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:)))) @@ -35,10 +35,11 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr) y = $call function $pullback_name(ȳ) ∂x = @thunk broadcast(x, ȳ) do xi, ȳi - _, ∂xi = _checked_rrule(f, xi) - extern(∂xi(ȳi)) + _, pullback_f = _checked_rrule(f, xi) + _, ∂xi = pullback_f(ȳi) + extern(∂xi) end - (NO_FIELDS, DNERule(), DNERule(), ∂x) + (NO_FIELDS, DNE(), DNE(), ∂x) end return y, $pullback_name end @@ -56,7 +57,7 @@ rrule(::typeof(sum), x) = (sum(x), ȳ->(NO_FIELDS, cast(ȳ))) function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims) function sum_pullback(ȳ) - NO_FIELDS, DNERule(), @thunk(last(mr_pullback(ȳ))) + NO_FIELDS, DNE(), last(mr_pullback(ȳ)) end return y, sum_pullback end @@ -64,7 +65,7 @@ end function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:) y, inner_pullback = rrule(sum, identity, x; dims=dims) function sum_pullback(ȳ) - NO_FIELDS, DNERule(), @thunk(last(inner_pullback(ȳ))) + NO_FIELDS, last(inner_pullback(ȳ)) end return y, sum_pullback end @@ -72,7 +73,7 @@ end function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) y = sum(abs2, x; dims=dims) function 
sum_abs2_pullback(ȳ) - (NO_FIELDS, DNERule(), @thunk(2ȳ .* x)) + (NO_FIELDS, DNE(), @thunk(2ȳ .* x)) end return y, sum_abs2_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 42e8c22cf..0c4ccbbee 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -102,17 +102,30 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr return Y, slash_pullback end +function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) + Aᵀ, dA = rrule(adjoint, A) + Bᵀ, dB = rrule(adjoint, B) + Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ) + C, dC = rrule(adjoint, Cᵀ) + ∂A = Rule(dA∘dAᵀ∘dC) + ∂B = Rule(dA∘dBᵀ∘dC) + return C, (∂A, ∂B) +end function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) Aᵀ, dA_pb = rrule(adjoint, A) Bᵀ, dB_pb = rrule(adjoint, B) Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) C, dC_pb = rrule(adjoint, Cᵀ) function slash_pullback(Ȳ) + # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want + # this is not a problem if you want the 2nd or 3rd, but if you want the first, it + # is fairly wasteful _, dC = dC_pb(Ȳ) - _, dAᵀ, dBᵀ = dS_pb(dC) + _, dBᵀ, dAᵀ = dS_pb(extern(dC)) - ∂A = @thunk last(dA_pb(dAᵀ)) - ∂B = @thunk last(dA_pb(dBᵀ)) + # need to extern as dAᵀ, dBᵀ are generally `Thunk`s, which don't support adjoint + ∂A = @thunk last(dA_pb(extern(dAᵀ))) + ∂B = @thunk last(dA_pb(extern(dBᵀ))) (NO_FIELDS, ∂A, ∂B) end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 915cacca9..591e9e138 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -1,20 +1,21 @@ # Structured matrices + ##### ##### `Diagonal` ##### -rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), (NO_FIELDS, Rule(diag)) +rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), ȳ->(NO_FIELDS, diag(ȳ)) -rrule(::typeof(diag), 
A::AbstractMatrix) = diag(A), (NO_FIELDS, Rule(Diagonal)) +rrule(::typeof(diag), A::AbstractMatrix) = diag(A), ȳ->(NO_FIELDS, Diagonal(ȳ)) ##### ##### `Symmetric` ##### -rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), (NO_FIELDS, Rule(_symmetric_back)) +rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), ȳ->(NO_FIELDS, _symmetric_back(ȳ)) -_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) +_symmetric_back(ΔΩ) = @thunk(UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)) _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### @@ -22,26 +23,26 @@ _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### # TODO: Deal with complex-valued arrays as well -rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), (NO_FIELDS, Rule(adjoint)) -rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), (NO_FIELDS, Rule(vec∘adjoint)) +rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), ȳ->(NO_FIELDS, adjoint(ȳ)) +rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), ȳ->(NO_FIELDS, vec(adjoint(ȳ))) -rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), (NO_FIELDS, Rule(adjoint)) -rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), (NO_FIELDS, Rule(vec∘adjoint)) +rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), ȳ->(NO_FIELDS, adjoint(ȳ)) +rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), ȳ->(NO_FIELDS, vec(adjoint(ȳ))) ##### ##### `Transpose` ##### -rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), (NO_FIELDS, Rule(transpose)) -rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), (NO_FIELDS, Rule(vec∘transpose)) +rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), ȳ->(NO_FIELDS, transpose(ȳ)) +rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), ȳ->(NO_FIELDS, vec(transpose(ȳ))) -rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), (NO_FIELDS, Rule(transpose)) 
-rrule(::typeof(transpose), A::AbstractVector) = transpose(A), (NO_FIELDS, Rule(vec∘transpose)) +rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), ȳ->(NO_FIELDS, transpose(ȳ)) +rrule(::typeof(transpose), A::AbstractVector) = transpose(A), ȳ->(NO_FIELDS, vec(transpose(ȳ))) ##### ##### Triangular matrices ##### -rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), (NO_FIELDS, Rule(Matrix)) +rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), ȳ->(NO_FIELDS, Matrix(ȳ)) -rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), (NO_FIELDS, Rule(Matrix)) +rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), ȳ->(NO_FIELDS, Matrix(ȳ)) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 6c3892067..5efc422d2 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -15,11 +15,12 @@ vx = randn(rng, n) ȳ = randn(rng) rrule_test(mapreduce, ȳ, (sin, nothing), (+, nothing), (x, vx)) + # With keyword arguments (not yet supported in rrule_test) X = randn(rng, n, n) - y, (_, _, dx) = rrule(mapreduce, abs2, +, X; dims=2) + y, pullback = rrule(mapreduce, abs2, +, X; dims=2) ȳ = randn(rng, size(y)) - x̄_ad = dx(ȳ) + (_, _, _, x̄_ad) = pullback(ȳ) x̄_fd = j′vp(central_fdm(5, 1), x->mapreduce(abs2, +, x; dims=2), ȳ, X) @test x̄_ad ≈ x̄_fd atol=1e-9 rtol=1e-9 end @@ -57,9 +58,9 @@ rng = MersenneTwister(33) n = 4 X = randn(rng, n, n) - y, dX = rrule(sum, X; dims=2) + y, pullback = rrule(sum, X; dims=2) ȳ = randn(rng, size(y)) - x̄_ad = dX(ȳ) + _, x̄_ad = pullback(ȳ) x̄_fd = j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X) @test x̄_ad ≈ x̄_fd atol=1e-9 rtol=1e-9 end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index dcc861b5b..beacac383 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -66,17 +66,20 @@ end end @testset "$f" for f in [/, \] rng = 
MersenneTwister(42) - for n in 3:5, m in 3:5 - A = randn(rng, m, n) - B = randn(rng, m, n) - Ȳ = randn(rng, size(f(A, B))) - rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n))) + @testset "Matrix" begin + for n in 3:5, m in 3:5 + A = randn(rng, m, n) + B = randn(rng, m, n) + Ȳ = randn(rng, size(f(A, B))) + rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n))) + end + end + @testset "Vector" begin + x = randn(rng, 10) + y = randn(rng, 10) + ȳ = randn(rng, size(f(x, y))...) + rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10))) end - # Vectors - x = randn(rng, 10) - y = randn(rng, 10) - ȳ = randn(rng, size(f(x, y))...) - rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10))) if f == (/) @testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular) RHS = T(randn(rng, T == Diagonal ? 10 : (10, 10))) diff --git a/test/test_util.jl b/test/test_util.jl index faab75871..6f9adde54 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -5,6 +5,7 @@ using ChainRulesCore: AbstractDifferential const _fdm = central_fdm(5, 1) + """ test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...) @@ -63,6 +64,10 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa end function ensure_not_running_on_functor(f, name) + # if x itself is a Type, then it is a constructor, thus not a functor. + # This also catchs UnionAll constructors which have a `:var` and `:body` fields + f isa Type && return + if fieldcount(typeof(f)) > 0 throw(ArgumentError( "$name cannot be used on closures/functors (such as $f)" @@ -161,6 +166,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm y, pullback = rrule(f, xs...) @test f(xs...) 
== y + @assert !(isa(ȳ, Thunk)) ∂s = pullback(ȳ) ∂self = ∂s[1] x̄s_ad = ∂s[2:end] @@ -173,7 +179,8 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm # The way we've structured the above, this tests that the rule is a DNERule @test x̄_ad isa DNE else - @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) + # TODO: remove extern from the line below, it is just there to make test output readign easier for now + @test isapprox(extern(x̄_ad), x̄_fd; rtol=rtol, atol=atol, kwargs...) end end From 908b7ea1b98505018a1c09956c006e404e72c1db Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 19:41:44 +0100 Subject: [PATCH 15/38] BLAS written but need to re-sort out update rules before done proper --- src/rulesets/LinearAlgebra/blas.jl | 145 +++++++++++++++++++---------- 1 file changed, 97 insertions(+), 48 deletions(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index d2edc7063..ad1ae9969 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -7,8 +7,6 @@ using LinearAlgebra: BlasFloat _zeros(x) = fill!(similar(x), zero(eltype(x))) -_rule_via(∂) = Rule(ΔΩ -> isa(ΔΩ, Zero) ? 
ΔΩ : ∂(extern(ΔΩ))) - ##### ##### `BLAS.dot` ##### @@ -19,9 +17,18 @@ rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y) function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) Ω = BLAS.dot(n, X, incx, Y, incy) - ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) - ∂Y = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) - return Ω, (NO_FIELDS, DNERule(), _rule_via(∂X), DNERule(), _rule_via(∂Y), DNERule()) + function blas_dot_pullback(ΔΩ) + if ΔΩ isa Zero + ∂X = Zero() + ∂Y = Zero() + else + ΔΩ = extern(ΔΩ) + ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) + ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) + end + return (NO_FIELDS, DNE(), ∂X, DNE(), ∂Y, DNE()) + end + return Ω, blas_dot_pullback end ##### @@ -30,18 +37,33 @@ end function frule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) - return Ω, (ZERO_RULE, Rule(Δx -> sum(Δx * cast(@thunk(x * inv(Ω)))))) + function nrm2_pushforward(_, Δx) + return sum(Δx * cast(@thunk(x * inv(Ω)))) + end + return Ω, nrm2_pushforward end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) - return Ω, (NO_FIELDS, Rule(ΔΩ -> ΔΩ * @thunk(x * inv(Ω)))) + function nrm2_pullback(ΔΩ) + return NO_FIELDS, @thunk(ΔΩ * @thunk(x * inv(Ω))) + end + return Ω, nrm2_pullback end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) - ∂X = ΔΩ -> scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) - return Ω, (NO_FIELDS, DNERule(), _rule_via(∂X), DNERule()) + function nrm2_pullback(ΔΩ) + if ΔΩ isa Zero + ∂X = Zero() + else + ΔΩ = extern(ΔΩ) + ∂X = scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) + end + return (NO_FIELDS, DNE(), ∂X, DNE()) + end + + return Ω, nrm2_pullback end ##### @@ -49,17 +71,29 @@ end ##### function frule(::typeof(BLAS.asum), x) - return BLAS.asum(x), (ZERO_RULE, Rule(Δx -> sum(cast(sign, x) * Δx))) + return BLAS.asum(x), (_, Δx) -> sum(cast(sign, x) * Δx) end function rrule(::typeof(BLAS.asum), x) - 
return BLAS.asum(x), (NO_FIELDS, Rule(ΔΩ -> ΔΩ * cast(sign, x))) + return BLAS.asum(x), ΔΩ -> (NO_FIELDS, @thunk(ΔΩ * cast(sign, x))) end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) - ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, sign.(X), incx, _zeros(X), incx), incx) - return Ω, (NO_FIELDS, DNERule(), _rule_via(∂X), DNERule()) + function asum_pullback(ΔΩ) + if ΔΩ isa Zero + ∂X = Zero() + else + ΔΩ = extern(ΔΩ) + ∂X = @thunk scal!( + n, ΔΩ, + blascopy!(n, sign.(X), incx, _zeros(X), incx), + incx + ) + end + return (NO_FIELDS, DNE(), ∂X, DNE()) + end + return Ω, asum_pullback end ##### @@ -69,20 +103,27 @@ end function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat y = gemv(tA, α, A, x) - if uppercase(tA) === 'N' - ∂A = Rule(ȳ -> α * ȳ * x', (Ā, ȳ) -> ger!(α, ȳ, x, Ā)) - ∂x = Rule(ȳ -> gemv('T', α, A, ȳ), (x̄, ȳ) -> gemv!('T', α, A, ȳ, one(T), x̄)) - else - ∂A = Rule(ȳ -> α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā)) - ∂x = Rule(ȳ -> gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄)) + function gemv_pullback(ȳ) + if uppercase(tA) === 'N' + ∂A = @thunk(α * ȳ * x', (Ā, ȳ) -> ger!(α, ȳ, x, Ā)) + ∂x = @thunk(gemv('T', α, A, ȳ), (x̄, ȳ) -> gemv!('T', α, A, ȳ, one(T), x̄)) + else + ∂A = @thunk(α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā)) + ∂x = @thunk(gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄)) + end + return (NO_FIELDS, DNE(), @thunk(dot(ȳ, y) / α), ∂A, ∂x) end - return y, (NO_FIELDS, DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x) + return y, gemv_pullback end function rrule(::typeof(gemv), tA::Char, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat - y, (dtA, _, dA, dx) = rrule(gemv, tA, one(T), A, x) - return y, (NO_FIELDS, dtA, dA, dx) + y, inner_pullback = rrule(gemv, tA, one(T), A, x) + function gemv_pullback(Ȳ) + (_, dtA, _, dA, dx) = inner_pullback(Ȳ) + return (NO_FIELDS, dtA, dA, dx) + end + return y, gemv_pullback end ##### @@ -92,37 +133,45 @@ end 
function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat C = gemm(tA, tB, α, A, B) - β = one(T) - if uppercase(tA) === 'N' - if uppercase(tB) === 'N' - ∂A = Rule(C̄ -> gemm('N', 'T', α, C̄, B), - (Ā, C̄) -> gemm!('N', 'T', α, C̄, B, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'N', α, A, C̄), - (B̄, C̄) -> gemm!('T', 'N', α, A, C̄, β, B̄)) - else - ∂A = Rule(C̄ -> gemm('N', 'N', α, C̄, B), - (Ā, C̄) -> gemm!('N', 'N', α, C̄, B, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'N', α, C̄, A), - (B̄, C̄) -> gemm!('T', 'N', α, C̄, A, β, B̄)) - end - else - if uppercase(tB) === 'N' - ∂A = Rule(C̄ -> gemm('N', 'T', α, B, C̄), - (Ā, C̄) -> gemm!('N', 'T', α, B, C̄, β, Ā)) - ∂B = Rule(C̄ -> gemm('N', 'N', α, A, C̄), - (B̄, C̄) -> gemm!('N', 'N', α, A, C̄, β, B̄)) + function gemv_pullback(C̄) + β = one(T) + if uppercase(tA) === 'N' + if uppercase(tB) === 'N' + ∂A = @thunk(gemm('N', 'T', α, C̄, B)) + ∂A_update = Ā -> gemm!('N', 'T', α, C̄, B, β, Ā) + ∂B = @thunk(gemm('T', 'N', α, A, C̄)) + ∂B_update = B̄ -> gemm!('T', 'N', α, A, C̄, β, B̄) + else + ∂A = @thunk(gemm('N', 'N', α, C̄, B)) + ∂A_update = Ā -> gemm!('N', 'N', α, C̄, B, β, Ā) + ∂B = @thunk(gemm('T', 'N', α, C̄, A)) + ∂B_update = B̄ -> gemm!('T', 'N', α, C̄, A, β, B̄) + end else - ∂A = Rule(C̄ -> gemm('T', 'T', α, B, C̄), - (Ā, C̄) -> gemm!('T', 'T', α, B, C̄, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'T', α, C̄, A), - (B̄, C̄) -> gemm!('T', 'T', α, C̄, A, β, B̄)) + if uppercase(tB) === 'N' + ∂A = @thunk(gemm('N', 'T', α, B, C̄)) + ∂A_update = Ā -> gemm!('N', 'T', α, B, C̄, β, Ā) + ∂B = @thunk(gemm('N', 'N', α, A, C̄)) + ∂B_update = B̄ -> gemm!('N', 'N', α, A, C̄, β, B̄) + else + ∂A = @thunk(gemm('T', 'T', α, B, C̄)) + ∂A_update = Ā -> gemm!('T', 'T', α, B, C̄, β, Ā) + ∂B = @thunk(gemm('T', 'T', α, C̄, A)) + ∂A_update = B̄ -> gemm!('T', 'T', α, C̄, A, β, B̄) + end end + # TODO: ∂A_update and ∂B_update. 
Requires working out update rules in the post #30 world + return (NO_FIELDS, DNE(), DNE(), @thunk(dot(C̄, C) / α), ∂A, ∂B) end - return C, (NO_FIELDS, DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B) + return C, gemv_pullback end function rrule(::typeof(gemm), tA::Char, tB::Char, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat - C, (dtA, dtB, _, dA, dB) = rrule(gemm, tA, tB, one(T), A, B) - return C, (NO_FIELDS, dtA, dtB, dA, dB) + C, inner_pullback = rrule(gemm, tA, tB, one(T), A, B) + function gemv_pullback(Ȳ) + (_, dtA, dtB, _, dA, dB) = inner_pullback(Ȳ) + return (NO_FIELDS, dtA, dtB, dA, dB) + end + return C, gemm_pullback end From 68c89e286724ae74b5b07ef41ac9648a27fb32d9 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 10 Sep 2019 11:46:44 +0100 Subject: [PATCH 16/38] BLAS rules working but update accumulation inplace is diabled --- src/ChainRules.jl | 2 +- src/rulesets/LinearAlgebra/blas.jl | 15 ++++++++++----- src/rulesets/LinearAlgebra/dense.jl | 9 --------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index c7afefb90..5c811e127 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -33,7 +33,7 @@ include("rulesets/Base/mapreduce.jl") #include("rulesets/Statistics/statistics.jl") include("rulesets/LinearAlgebra/utils.jl") -#include("rulesets/LinearAlgebra/blas.jl") +include("rulesets/LinearAlgebra/blas.jl") include("rulesets/LinearAlgebra/dense.jl") include("rulesets/LinearAlgebra/structured.jl") include("rulesets/LinearAlgebra/factorization.jl") diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index ad1ae9969..18bf2f4d4 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -104,12 +104,17 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat y = gemv(tA, α, A, x) function gemv_pullback(ȳ) + # TODO: make use of update rules if 
uppercase(tA) === 'N' - ∂A = @thunk(α * ȳ * x', (Ā, ȳ) -> ger!(α, ȳ, x, Ā)) - ∂x = @thunk(gemv('T', α, A, ȳ), (x̄, ȳ) -> gemv!('T', α, A, ȳ, one(T), x̄)) + ∂A = @thunk(α * ȳ * x') + ∂A_update = Ā -> ger!(α, ȳ, x, Ā) + ∂x = @thunk(gemv('T', α, A, ȳ)) + ∂x_update = x̄ -> gemv!('T', α, A, ȳ, one(T), x̄) else - ∂A = @thunk(α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā)) - ∂x = @thunk(gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄)) + ∂A = @thunk(α * x * ȳ') + ∂A_update = Ā -> ger!(α, x, ȳ, Ā) + ∂x = @thunk(gemv('N', α, A, ȳ)) + ∂x_update = x̄ -> gemv!('N', α, A, ȳ, one(T), x̄) end return (NO_FIELDS, DNE(), @thunk(dot(ȳ, y) / α), ∂A, ∂x) end @@ -160,7 +165,7 @@ function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, ∂A_update = B̄ -> gemm!('T', 'T', α, C̄, A, β, B̄) end end - # TODO: ∂A_update and ∂B_update. Requires working out update rules in the post #30 world + # TODO: use ∂A_update and ∂B_update. Requires working out update rules in the post #30 world return (NO_FIELDS, DNE(), DNE(), @thunk(dot(C̄, C) / α), ∂A, ∂B) end return C, gemv_pullback diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 0c4ccbbee..3a0aa4d09 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -102,15 +102,6 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr return Y, slash_pullback end -function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Aᵀ, dA = rrule(adjoint, A) - Bᵀ, dB = rrule(adjoint, B) - Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ) - C, dC = rrule(adjoint, Cᵀ) - ∂A = Rule(dA∘dAᵀ∘dC) - ∂B = Rule(dA∘dBᵀ∘dC) - return C, (∂A, ∂B) -end function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) Aᵀ, dA_pb = rrule(adjoint, A) Bᵀ, dB_pb = rrule(adjoint, B) From 33ae54e5f25abbfaf57c1f0eca427103093ecaa9 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 10 Sep 2019 14:03:12 +0100 Subject: [PATCH 17/38] Factorizations 
working --- src/rulesets/LinearAlgebra/factorization.jl | 68 ++++++++++++-------- test/rulesets/LinearAlgebra/factorization.jl | 41 ++++++++---- 2 files changed, 68 insertions(+), 41 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index cf89e9351..661b69faa 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -7,25 +7,31 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) F = svd(X) - ∂X = Rule() do Ȳ::NamedTuple{(:U,:S,:V)} - svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) + function svd_pullback(Ȳ::NamedTuple{(:U,:S,:V)}) + ∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)) + return (NO_FIELDS, ∂X) end - return F, (NO_FIELDS, ∂X) + return F, svd_pullback end function rrule(::typeof(getproperty), F::SVD, x::Symbol) - if x === :U - rule = Ȳ->(U=Ȳ, S=zero(F.S), V=zero(F.V)) - elseif x === :S - rule = Ȳ->(U=zero(F.U), S=Ȳ, V=zero(F.V)) - elseif x === :V - rule = Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ) - elseif x === :Vt - # TODO: This could be made to work, but it'd be a pain - throw(ArgumentError("Vt is unsupported; use V and transpose the result")) + function getproperty_svd_pullback(Ȳ) + if x === :U + ∂ = @thunk((; U=Ȳ, S=(zero(F.S)), V=(zero(F.V)))) + elseif x === :S + ∂ = @thunk((; U=(zero(F.U)), S=Ȳ, V=(zero(F.V)))) + elseif x === :V + ∂ = @thunk((; U=(zero(F.U)), S=(zero(F.S)), V=Ȳ)) + elseif x === :Vt + # TODO: This could be made to work, but it'd be a pain + throw(ArgumentError("Vt is unsupported; use V and transpose the result")) + end + # TODO use update + update = (X̄::NamedTuple{(:U,:S,:V)}) -> _update!(X̄, ∂, x) + + return NO_FIELDS, ∂, DNE() end - update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x) - return getproperty(F, x), (NO_FIELDS, Rule(rule, update), DNERule()) + return getproperty(F, x), getproperty_svd_pullback end function svd_rev(USV::SVD, Ū::AbstractMatrix, 
s̄::AbstractVector, V̄::AbstractMatrix) @@ -65,25 +71,31 @@ end function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) F = cholesky(X) - ∂X = Rule(Ȳ->chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) - return F, (NO_FIELDS, ∂X) + function cholesky_pullback(Ȳ) + ∂X = @thunk(chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) + return (NO_FIELDS, ∂X) + end + return F, cholesky_pullback end function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) - if x === :U - if F.uplo === 'U' - ∂F = Ȳ->UpperTriangular(Ȳ) - else - ∂F = Ȳ->LowerTriangular(Ȳ') - end - elseif x === :L - if F.uplo === 'L' - ∂F = Ȳ->LowerTriangular(Ȳ) - else - ∂F = Ȳ->UpperTriangular(Ȳ') + function getproperty_cholesky_pullback(Ȳ) + if x === :U + if F.uplo === 'U' + ∂F = @thunk UpperTriangular(Ȳ) + else + ∂F = @thunk LowerTriangular(Ȳ') + end + elseif x === :L + if F.uplo === 'L' + ∂F = @thunk LowerTriangular(Ȳ) + else + ∂F = @thunk UpperTriangular(Ȳ') + end end + return NO_FIELDS, ∂F, DNE() end - return getproperty(F, x), (NO_FIELDS, Rule(∂F), DNERule()) + return getproperty(F, x), getproperty_cholesky_pullback end # See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index a8d925f23..f2364c994 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -2,21 +2,32 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @testset "Factorizations" begin @testset "svd" begin - rng = MersenneTwister(2) + rng = MersenneTwister(3) for n in [4, 6, 10], m in [3, 5, 10] X = randn(rng, n, m) - F, dX = rrule(svd, X) + F, dX_pullback = rrule(svd, X) for p in [:U, :S, :V] - Y, (dF, dp) = rrule(getproperty, F, p) - @test dp isa ChainRules.DNERule + Y, dF_pullback = rrule(getproperty, F, p) Ȳ = randn(rng, size(Y)...) 
- X̄_ad = dX(dF(Ȳ)) + + dself1, dF, dp = dF_pullback(Ȳ) + @test dself1 === NO_FIELDS + @test dp === DNE() + + ΔF = extern(dF) + dself2, dX = dX_pullback(ΔF) + @test dself2 === NO_FIELDS + X̄_ad = extern(dX) X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X) @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 end - @test_throws ArgumentError rrule(getproperty, F, :Vt) + @testset "Vt" begin + Y, dF_pullback = rrule(getproperty, F, :Vt) + Ȳ = randn(rng, size(Y)...) + @test_throws ArgumentError dF_pullback(Ȳ) + end end - #== TODO: re-enable me + #== TODO: re-enable me, once updating rules work @testset "accumulate!" begin X = [1.0 2.0; 3.0 4.0; 5.0 6.0] F, dX = rrule(svd, X) @@ -44,18 +55,22 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @testset "the thing" begin X = generate_well_conditioned_matrix(rng, 10) V = generate_well_conditioned_matrix(rng, 10) - F, dX = rrule(cholesky, X) + F, dX_pullback = rrule(cholesky, X) for p in [:U, :L] - Y, (dself, dF, dp) = rrule(getproperty, F, p) - @test dself === NO_FIELDS - @test dp isa ChainRules.DNERule + Y, dF_pullback = rrule(getproperty, F, p) Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y))) + (dself, dF, dp) = dF_pullback(Ȳ) + @test dself === NO_FIELDS + @test dp === DNE() + # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` # machinery from FiniteDifferences because that isn't set up to respect # necessary special properties of the input. In the case of the Cholesky # factorization, we need the input to be Hermitian. 
- X̄_ad = dot(dX(dF(Ȳ)), V) - X̄_fd = central_fdm(5, 1)() do ε + ΔF = extern(dF) + _, dX = dX_pullback(ΔF) + X̄_ad = dot(extern(dX), V) + X̄_fd = central_fdm(5,1)() do ε dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)) end @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 From dff859adbd67b9c4cf79d94781e093cfb3e2b94b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 10 Sep 2019 14:41:32 +0100 Subject: [PATCH 18/38] Make statistics work --- src/ChainRules.jl | 3 +-- src/rulesets/Statistics/statistics.jl | 22 +++++++++++++--- test/rulesets/Statistics/statistics.jl | 36 +++++++++++++++++++++----- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 5c811e127..d430c36bd 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -29,8 +29,7 @@ include("rulesets/Base/array.jl") include("rulesets/Base/broadcast.jl") include("rulesets/Base/mapreduce.jl") - -#include("rulesets/Statistics/statistics.jl") +include("rulesets/Statistics/statistics.jl") include("rulesets/LinearAlgebra/utils.jl") include("rulesets/LinearAlgebra/blas.jl") diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 2be434ce3..8b96e6027 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -9,13 +9,27 @@ _denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1) # TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36 function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) - _, dx = rrule(sum, x; dims=dims) + y_sum, sum_pullback = rrule(sum, x; dims=dims) n = _denom(x, dims) - return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n) + function mean_pullback(ȳ) + ∂x = Thunk() do + _, ∂sum_x = sum_pullback(ȳ) + extern(∂sum_x)/n + end + return (NO_FIELDS, ∂x) + end + return y_sum/n, mean_pullback end function rrule(::typeof(mean), f, x::AbstractArray{<:Real}) - _, (_, dx) = rrule(sum, f, x) + y_sum, sum_pullback = rrule(sum, f, x) n = _denom(x, :) - return 
mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n)) + function mean_pullback(ȳ) + ∂x = Thunk() do + _, _, ∂sum_x = sum_pullback(ȳ) + extern(∂sum_x)/n + end + return (NO_FIELDS, DNE(), ∂x) + end + return y_sum/n, mean_pullback end diff --git a/test/rulesets/Statistics/statistics.jl b/test/rulesets/Statistics/statistics.jl index b36d5fbe9..5ddafe0e7 100644 --- a/test/rulesets/Statistics/statistics.jl +++ b/test/rulesets/Statistics/statistics.jl @@ -1,11 +1,33 @@ @testset "mean" begin rng = MersenneTwister(999) n = 9 - rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n))) - X = randn(rng, n, n) - y, dX = rrule(mean, X; dims=1) - ȳ = randn(rng, size(y)) - X̄_ad = dX(ȳ) - X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X) - @test X̄_ad ≈ X̄_fd rtol=1e-9 atol=1e-9 + + @testset "Basic" begin + rrule_test( + mean, + randn(rng), + (randn(rng, n), + randn(rng, n)) + ) + end + + @testset "with function arg" begin + rrule_test( + mean, + randn(rng), + (abs2, nothing), + (randn(rng, n), + randn(rng, n)) + ) + end + + @testset "with dims kwargs" begin + X = randn(rng, n, n) + y, mean_pullback = rrule(mean, X; dims=1) + ȳ = randn(rng, size(y)) + _, dX = mean_pullback(ȳ) + X̄_ad = extern(dX) + X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X) + @test X̄_ad ≈ X̄_fd rtol=1e-9 atol=1e-9 + end end From a434f2a970e977ac3660a953120df78f813050e4 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 10 Sep 2019 15:20:50 +0100 Subject: [PATCH 19/38] remove double extern --- test/rulesets/Base/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 4dd4593ae..9dc592008 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -7,7 +7,7 @@ (dself, dsin, dx) = pullback(One()) @test dself == NO_FIELDS @test dsin == DNE() - @test extern(extern(dx)) == cos.(x) + @test extern(dx) == cos.(x) x̄, ȳ = rand(), rand() ∂x = pullback(ȳ)[3] From 
aafadfabe4082b330f609c2c146ff40065951041 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 11 Sep 2019 12:54:57 +0100 Subject: [PATCH 20/38] fix bad rebase --- test/rulesets/Base/base.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 9f90fa180..d570e33f4 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -127,8 +127,7 @@ test_scalar(conj, x; rtol=rtol) end end - - #== TODO Renable me + @testset "*(x, y)" begin x, y = rand(3, 2), rand(2, 5) z, pullback = rrule(*, x, y) From 4b7132574b8dbfd2a1e9963bc3f670434abcb95b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 11 Sep 2019 17:12:51 +0100 Subject: [PATCH 21/38] WIP use InplaceableThunks for updating rules --- src/rulesets/LinearAlgebra/blas.jl | 74 +++++++++++++------- src/rulesets/LinearAlgebra/dense.jl | 2 +- src/rulesets/LinearAlgebra/factorization.jl | 6 +- src/rulesets/LinearAlgebra/structured.jl | 2 +- test/rulesets/LinearAlgebra/factorization.jl | 4 +- test/test_util.jl | 8 ++- 6 files changed, 60 insertions(+), 36 deletions(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 18bf2f4d4..4583a7066 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -104,17 +104,24 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat y = gemv(tA, α, A, x) function gemv_pullback(ȳ) - # TODO: make use of update rules if uppercase(tA) === 'N' - ∂A = @thunk(α * ȳ * x') - ∂A_update = Ā -> ger!(α, ȳ, x, Ā) - ∂x = @thunk(gemv('T', α, A, ȳ)) - ∂x_update = x̄ -> gemv!('T', α, A, ȳ, one(T), x̄) + ∂A = InplaceableThunk( + @thunk(α * ȳ * x'), + Ā -> ger!(α, ȳ, x, Ā) + ) + ∂x = InplaceableThunk( + @thunk(gemv('T', α, A, ȳ)), + x̄ -> gemv!('T', α, A, ȳ, one(T), x̄) + ) else - ∂A = @thunk(α * x * ȳ') - ∂A_update = Ā -> ger!(α, x, ȳ, Ā) - ∂x = @thunk(gemv('N', α, A, ȳ)) - ∂x_update = x̄ 
-> gemv!('N', α, A, ȳ, one(T), x̄) + ∂A = InplaceableThunk( + @thunk(α * x * ȳ), + Ā -> ger!(α, x, ȳ, Ā) + ) + ∂x = InplaceableThunk( + @thunk(gemv('N', α, A, ȳ)), + x̄ -> gemv!('N', α, A, ȳ, one(T), x̄) + ) end return (NO_FIELDS, DNE(), @thunk(dot(ȳ, y) / α), ∂A, ∂x) end @@ -142,30 +149,45 @@ function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, β = one(T) if uppercase(tA) === 'N' if uppercase(tB) === 'N' - ∂A = @thunk(gemm('N', 'T', α, C̄, B)) - ∂A_update = Ā -> gemm!('N', 'T', α, C̄, B, β, Ā) - ∂B = @thunk(gemm('T', 'N', α, A, C̄)) - ∂B_update = B̄ -> gemm!('T', 'N', α, A, C̄, β, B̄) + ∂A = InplaceableThunk( + @thunk(gemm('N', 'T', α, C̄, B)), + Ā -> gemm!('N', 'T', α, C̄, B, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('T', 'N', α, A, C̄)), + B̄ -> gemm!('T', 'N', α, A, C̄, β, B̄) + ) else - ∂A = @thunk(gemm('N', 'N', α, C̄, B)) - ∂A_update = Ā -> gemm!('N', 'N', α, C̄, B, β, Ā) - ∂B = @thunk(gemm('T', 'N', α, C̄, A)) - ∂B_update = B̄ -> gemm!('T', 'N', α, C̄, A, β, B̄) + ∂A = InplaceableThunk( + @thunk(gemm('N', 'N', α, C̄, B)), + Ā -> gemm!('N', 'N', α, C̄, B, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('T', 'N', α, C̄, A)), + B̄ -> gemm!('T', 'N', α, C̄, A, β, B̄) + ) end else if uppercase(tB) === 'N' - ∂A = @thunk(gemm('N', 'T', α, B, C̄)) - ∂A_update = Ā -> gemm!('N', 'T', α, B, C̄, β, Ā) - ∂B = @thunk(gemm('N', 'N', α, A, C̄)) - ∂B_update = B̄ -> gemm!('N', 'N', α, A, C̄, β, B̄) + ∂A = InplaceableThunk( + @thunk(gemm('N', 'T', α, B, C̄)), + Ā -> gemm!('N', 'T', α, B, C̄, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('N', 'N', α, A, C̄)), + B̄ -> gemm!('N', 'N', α, A, C̄, β, B̄) + ) else - ∂A = @thunk(gemm('T', 'T', α, B, C̄)) - ∂A_update = Ā -> gemm!('T', 'T', α, B, C̄, β, Ā) - ∂B = @thunk(gemm('T', 'T', α, C̄, A)) - ∂A_update = B̄ -> gemm!('T', 'T', α, C̄, A, β, B̄) + ∂A = InplaceableThunk( + @thunk(gemm('T', 'T', α, B, C̄)), + Ā -> gemm!('T', 'T', α, B, C̄, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('T', 'T', α, C̄, A)), + B̄ -> 
gemm!('T', 'T', α, C̄, A, β, B̄) + ) end end - # TODO: use ∂A_update and ∂B_update. Requires working out update rules in the post #30 world return (NO_FIELDS, DNE(), DNE(), @thunk(dot(C̄, C) / α), ∂A, ∂B) end return C, gemv_pullback diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 3a0aa4d09..203fa0af2 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -173,7 +173,7 @@ end function rrule(::typeof(norm), x::Real, p::Real=2) function norm_pullback(ȳ) ∂x = @thunk ȳ * sign(x) - ∂p = @thunk zero(x) #TODO: should this be Zero() + ∂p = @thunk zero(x) #TODO: should this be Zero()? (NO_FIELDS, ∂x, ∂p) end return norm(x, p), norm_pullback diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 661b69faa..6fc69ebee 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -26,10 +26,10 @@ function rrule(::typeof(getproperty), F::SVD, x::Symbol) # TODO: This could be made to work, but it'd be a pain throw(ArgumentError("Vt is unsupported; use V and transpose the result")) end - # TODO use update - update = (X̄::NamedTuple{(:U,:S,:V)}) -> _update!(X̄, ∂, x) - return NO_FIELDS, ∂, DNE() + update = (X̄::NamedTuple{(:U,:S,:V)}) -> _update!(X̄, ∂, x) + ∂F = InplaceableThunk(∂, update) + return NO_FIELDS, ∂F, DNE() end return getproperty(F, x), getproperty_svd_pullback end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 591e9e138..8de6cb0ac 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -22,7 +22,7 @@ _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### `Adjoint` ##### -# TODO: Deal with complex-valued arrays as well +# ✖️✖️✖️TODO: Deal with complex-valued arrays as well rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), ȳ->(NO_FIELDS, adjoint(ȳ)) 
rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), ȳ->(NO_FIELDS, vec(adjoint(ȳ))) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index f2364c994..ec31389ae 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -27,7 +27,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @test_throws ArgumentError dF_pullback(Ȳ) end end - #== TODO: re-enable me, once updating rules work + @testset "accumulate!" begin X = [1.0 2.0; 3.0 4.0; 5.0 6.0] F, dX = rrule(svd, X) @@ -41,7 +41,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @test X̄.S ≈ ones(2) atol=1e-6 @test X̄.V ≈ ones(2, 2) atol=1e-6 end - ==# + @testset "Helper functions" begin X = randn(rng, 10, 10) Y = randn(rng, 10, 10) diff --git a/test/test_util.jl b/test/test_util.jl index 6f9adde54..a887535c2 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -179,8 +179,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm # The way we've structured the above, this tests that the rule is a DNERule @test x̄_ad isa DNE else - # TODO: remove extern from the line below, it is just there to make test output readign easier for now - @test isapprox(extern(x̄_ad), x̄_fd; rtol=rtol, atol=atol, kwargs...) + @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) end end @@ -251,6 +250,9 @@ test_store!(x̄::Number, dx, partial) = nothing function test_store!(x̄::AbstractArray, dx, partial) x̄_copy = copy(x̄) store!(x̄_copy, dx) - @test all(x̄_copy .≈ extern(partial)) + + @show extern(partial) + @show x̄_copy + @test x̄_copy ≈ extern(partial) return nothing end From c971052254ff5bf7056053820234f3fc648d7f14 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 11 Sep 2019 20:51:47 +0100 Subject: [PATCH 22/38] make factorizations accumulate! 
right --- src/helper_functions.jl | 22 +++--- src/rulesets/LinearAlgebra/blas.jl | 2 +- test/helper_functions.jl | 7 +- test/rulesets/Base/base.jl | 8 +-- test/rulesets/LinearAlgebra/factorization.jl | 11 +-- test/runtests.jl | 2 +- test/test_util.jl | 71 +++++++++----------- 7 files changed, 63 insertions(+), 60 deletions(-) diff --git a/src/helper_functions.jl b/src/helper_functions.jl index ddd371e60..341908128 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -1,8 +1,4 @@ -# Special purpose updating for operations which can be done in-place. This function is -# just internal and free-form; it is not a method of `accumulate!` directly as it does -# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`. -# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular -# rule. +# Internal helpers for defining the `add!` field of an `InplaceableThunk` _update!(x, y) = x + y _update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y @@ -11,18 +7,26 @@ _update!(x, ::Zero) = x _update!(::Zero, y) = y _update!(::Zero, ::Zero) = Zero() -function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns - return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns)) -end function _update!(x::NamedTuple, y, p::Symbol) - new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),)) + y = extern(y) + yp = getproperty(y, p) + xp = getproperty(x, p) + new_xp = _update!(xp, yp) + new = NamedTuple{(p,)}((new_xp,)) return merge(x, new) end +#== +function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns + return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns)) +end + + function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns return _update!(x, getproperty(y, p), p) end +==# """ _checked_rrule diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 4583a7066..bd8e42cb5 100644 --- 
a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -115,7 +115,7 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, ) else ∂A = InplaceableThunk( - @thunk(α * x * ȳ), + @thunk(α * x * ȳ'), Ā -> ger!(α, x, ȳ, Ā) ) ∂x = InplaceableThunk( diff --git a/test/helper_functions.jl b/test/helper_functions.jl index 7d3a8d170..062cb631a 100644 --- a/test/helper_functions.jl +++ b/test/helper_functions.jl @@ -19,10 +19,11 @@ end @testset "_update! NamedTuple" begin X = (A=[1 0; 0 1], B=[2 2; 2 2]) + old_X = deepcopy(X) Y = deepcopy(X) - @test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4]) - @test X.A != Y.A - @test X.B != Y.B + @test ChainRules._update!(X, Y, :A) == (A=[2 0; 0 2], B=[2 2; 2 2]) + @test X.A != old_X.A + @test X.B == old_X.B end @testset "_checked_rrule" begin try diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index d570e33f4..3e5e776f9 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -127,7 +127,7 @@ test_scalar(conj, x; rtol=rtol) end end - + @testset "*(x, y)" begin x, y = rand(3, 2), rand(2, 5) z, pullback = rrule(*, x, y) @@ -136,14 +136,14 @@ z̄ = rand(3, 5) (ds, dx, dy) = pullback(z̄) - + @test ds === NO_FIELDS @test extern(dx) == extern(accumulate(zeros(3, 2), dx)) @test extern(dy) == extern(accumulate(zeros(2, 5), dy)) - test_accumulation(rand(3, 2), dx, z̄ * y') - test_accumulation(rand(2, 5), dy, x' * z̄) + test_accumulation(rand(3, 2), dx) + test_accumulation(rand(2, 5), dy) end @testset "hypot(x, y)" begin diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index ec31389ae..5d4529dba 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -30,18 +30,21 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @testset "accumulate!" 
begin X = [1.0 2.0; 3.0 4.0; 5.0 6.0] - F, dX = rrule(svd, X) + F, dX_pullback = rrule(svd, X) X̄ = (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) for p in [:U, :S, :V] - Y, (dF, _) = rrule(getproperty, F, p) + Y, dF_pullback = rrule(getproperty, F, p) Ȳ = ones(size(Y)...) - ChainRules.accumulate!(X̄, dF, Ȳ) + (dself, dF, dp) = dF_pullback(Ȳ) + @test dself === NO_FIELDS + @test dp === DNE() + ChainRules.accumulate!(X̄, dF) end @test X̄.U ≈ ones(3, 2) atol=1e-6 @test X̄.S ≈ ones(2) atol=1e-6 @test X̄.V ≈ ones(2, 2) atol=1e-6 end - + @testset "Helper functions" begin X = randn(rng, 10, 10) Y = randn(rng, 10, 10) diff --git a/test/runtests.jl b/test/runtests.jl index 7512218e3..7379bb2f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,7 +12,7 @@ using Test # For testing purposes we use a lot of using ChainRulesCore: cast, extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, - Zero, One, Casted, DNE, Thunk, DNERule, AbstractDifferential + Zero, One, Casted, DNE, Thunk, AbstractDifferential include("test_util.jl") diff --git a/test/test_util.jl b/test/test_util.jl index a887535c2..6e09cbcf8 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -126,9 +126,8 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) # Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct. - # TODO is this test nonsense now? - test_accumulation(x̄, x̄_ad, x̄_ad) - test_accumulation(Zero(), x̄_ad, x̄_ad) + test_accumulation(x̄, x̄_ad) + test_accumulation(Zero(), x̄_ad) end function _make_fdm_call(fdm, f, ȳ, xs, ignores) @@ -186,8 +185,8 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm # Assuming the above to be correct, check that other ChainRules mechanisms are correct. 
for (x̄, x̄_ad) in zip(x̄s, x̄s_ad) x̄ === nothing && continue - test_accumulation(x̄, x̄_ad, x̄_ad) - test_accumulation(Zero(), x̄_ad, x̄_ad) + test_accumulation(x̄, x̄_ad) + test_accumulation(Zero(), x̄_ad) end end @@ -204,55 +203,51 @@ function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) return isapprox(extern(d_ad), d_fd; kwargs...) end -function test_accumulation(x̄, dx, partial) - @test all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) - test_accumulate(x̄, dx, partial) - test_accumulate!(x̄, dx, partial) - test_store!(x̄, dx, partial) - return nothing +function test_accumulation(x̄, ∂x) + @test all(extern(x̄ + ∂x) .≈ extern(x̄) .+ extern(∂x)) + test_accumulate(x̄, ∂x) + test_accumulate!(x̄, ∂x) + test_store!(x̄, ∂x) end -function test_accumulate(x̄::Zero, dx, partial) - @test extern(accumulate(x̄, dx)) ≈ extern(partial) - return nothing +function test_accumulate(x̄::Zero, ∂x) + @test extern(accumulate(x̄, ∂x)) ≈ extern(∂x) end -function test_accumulate(x̄::Number, dx, partial) - @test extern(accumulate(x̄, dx)) ≈ extern(x̄) + extern(partial) - return nothing +function test_accumulate(x̄::Number, ∂x) + @test extern(accumulate(x̄, ∂x)) ≈ extern(x̄) + extern(∂x) end -function test_accumulate(x̄::AbstractArray, dx, partial) +function test_accumulate(x̄::AbstractArray, ∂x) x̄_old = copy(x̄) - @test all(extern(accumulate(x̄, dx)) .≈ (extern(x̄) .+ extern(partial))) - @test x̄ == x̄_old - return nothing + @test all(extern(accumulate(x̄, ∂x)) .≈ (extern(x̄) .+ extern(∂x))) + @test x̄ == x̄_old # make sure didn't mutate x̄ end -test_accumulate!(x̄::Zero, dx, partial) = nothing +test_accumulate!(x̄::Zero, ∂x) = nothing -function test_accumulate!(x̄::Number, dx, partial) - @test accumulate!(x̄, dx) ≈ accumulate(x̄, dx) - return nothing +function test_accumulate!(x̄::Number, ∂x) + # This case won't have been inplace as `Number` is immutable + @test accumulate!(x̄, ∂x) ≈ accumulate(x̄, ∂x) end -function test_accumulate!(x̄::AbstractArray, dx, 
partial) +function test_accumulate!(x̄::AbstractArray, ∂x) x̄_copy = copy(x̄) - accumulate!(x̄_copy, dx) - @test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) - return nothing + accumulate!(x̄_copy, ∂x) # this should have actually been in-place + @test extern(x̄_copy) ≈ (extern(x̄) .+ extern(∂x)) end -test_store!(x̄::Zero, dx, partial) = nothing -test_store!(x̄::Number, dx, partial) = nothing +test_store!(x̄::Zero, ∂x) = nothing +test_store!(x̄::Number, ∂x) = nothing -function test_store!(x̄::AbstractArray, dx, partial) - x̄_copy = copy(x̄) - store!(x̄_copy, dx) +function test_store!(x̄::AbstractArray, ∂x) + x̄_store = copy(x̄) + store!(x̄_store, ∂x) + @test x̄_store ≈ extern(∂x) - @show extern(partial) - @show x̄_copy - @test x̄_copy ≈ extern(partial) - return nothing + # store! is the same as `accumulate!` to a zero array + x̄_acc = false.*x̄ + accumulate!(x̄_acc, ∂x) + @test x̄_acc ≈ x̄_store end From 89809263ec6d81a8db2b671fd6e3ddd773db46fd Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 18:57:45 +0100 Subject: [PATCH 23/38] style and typos Co-Authored-By: Nick Robinson --- src/rulesets/LinearAlgebra/dense.jl | 8 ++++---- src/rulesets/Statistics/statistics.jl | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 203fa0af2..f3dd32b04 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -129,18 +129,18 @@ end function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} Y = A \ B - function forwardslash_pullback(Ȳ) + function backslash_pullback(Ȳ) S = T.name.wrapper ∂A = @thunk S(-(A' \ Ȳ) * Y') ∂B = @thunk A' \ Ȳ return NO_FIELDS, ∂A, ∂B end - return Y, forwardslash_pullback + return Y, backslash_pullback end function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) Y = A \ B - function forwardslash_pullback(Ȳ) + function 
backslash_pullback(Ȳ) ∂A = @thunk begin B̄ = A' \ Ȳ Ā = -B̄ * Y' @@ -151,7 +151,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R ∂B = @thunk A' \ Ȳ return NO_FIELDS, ∂A, ∂B end - return Y, forwardslash_pullback + return Y, backslash_pullback end diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 8b96e6027..0c40fa36b 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -14,11 +14,11 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) function mean_pullback(ȳ) ∂x = Thunk() do _, ∂sum_x = sum_pullback(ȳ) - extern(∂sum_x)/n + extern(∂sum_x) / n end return (NO_FIELDS, ∂x) end - return y_sum/n, mean_pullback + return y_sum / n, mean_pullback end function rrule(::typeof(mean), f, x::AbstractArray{<:Real}) @@ -27,9 +27,9 @@ function rrule(::typeof(mean), f, x::AbstractArray{<:Real}) function mean_pullback(ȳ) ∂x = Thunk() do _, _, ∂sum_x = sum_pullback(ȳ) - extern(∂sum_x)/n + extern(∂sum_x) / n end return (NO_FIELDS, DNE(), ∂x) end - return y_sum/n, mean_pullback + return y_sum / n, mean_pullback end From 8c1e418a800ee5762be055ededa87c9dae4839c2 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 19:14:13 +0100 Subject: [PATCH 24/38] use _fdm rather than making a new central_fdm Co-Authored-By: Nick Robinson --- test/rulesets/LinearAlgebra/factorization.jl | 2 +- test/rulesets/Statistics/statistics.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 5d4529dba..0b29cec96 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -73,7 +73,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo ΔF = extern(dF) _, dX = dX_pullback(ΔF) X̄_ad = dot(extern(dX), V) - X̄_fd = central_fdm(5,1)() do ε + X̄_fd = _fdm() do ε dot(Ȳ, 
getproperty(cholesky(X .+ ε .* V), p)) end @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 diff --git a/test/rulesets/Statistics/statistics.jl b/test/rulesets/Statistics/statistics.jl index 5ddafe0e7..ee2af7671 100644 --- a/test/rulesets/Statistics/statistics.jl +++ b/test/rulesets/Statistics/statistics.jl @@ -27,7 +27,7 @@ ȳ = randn(rng, size(y)) _, dX = mean_pullback(ȳ) X̄_ad = extern(dX) - X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X) + X̄_fd = j′vp(_fdm, x->mean(x, dims=1), ȳ, X) @test X̄_ad ≈ X̄_fd rtol=1e-9 atol=1e-9 end end From b5d3ba944e01735f25aaaf13b8079229f9b2ed1f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 14:12:33 +0100 Subject: [PATCH 25/38] set version correctly Co-Authored-By: Nick Robinson --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 84f3ba69f..e43431f5c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.3.0" +version = "0.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 9eee9db054e11d0a3973f46d5627923aa7a65846 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 16:57:03 +0100 Subject: [PATCH 26/38] name some pullbacks --- src/rulesets/Base/array.jl | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 3be8994c8..fa69b17c7 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -3,14 +3,18 @@ ##### function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) - return reshape(A, dims), Ȳ -> (NO_FIELDS, reshape(Ȳ, dims), DNE()) + function reshape_pullback(Ȳ) + return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DNE()) + end + return reshape(A, dims), reshape_pullback end function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) 
- return ( - reshape(A, dims...), - Ȳ -> (NO_FIELDS, reshape(Ȳ, dims), fill(DNE(), length(dims))...) - ) + function reshape_pullback(Ȳ) + ∂A = @thunk(reshape(Ȳ, dims)) + return (NO_FIELDS, ∂A, fill(DNE(), length(dims))...) + end + return reshape(A, dims...), reshape_pullback end ##### @@ -58,9 +62,15 @@ end ##### function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) - return fill(value, dims), Ȳ -> (NO_FIELDS, sum(Ȳ), DNE()) + function fill_pullback(Ȳ) + return (NO_FIELDS, @thunk(sum(Ȳ)), DNE()) + end + return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) - return fill(value, dims), Ȳ -> (NO_FIELDS, sum(Ȳ), ntuple(_->DNE(), length(dims))...) + function fill_pullback(Ȳ) + return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DNE(), length(dims))...) + end + return fill(value, dims), fill_pullback end From 1288748a6e2720bf8e135544ed1acd13087c78ba Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 17:10:57 +0100 Subject: [PATCH 27/38] name more propagators --- src/rulesets/Base/base.jl | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index fb5c473dd..6ce4635a5 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -103,8 +103,30 @@ # product rule requires special care for arguments where `mul` is non-commutative -frule(::typeof(*), x::Number, y::Number) = x * y, (_, Δx, Δy) -> Δx * y + x * Δy -rrule(::typeof(*), x::Number, y::Number) = x * y, (ΔΩ -> (NO_FIELDS, ΔΩ * y', x' * ΔΩ)) - -frule(::typeof(identity), x) = x, (_, ȳ) -> ȳ -rrule(::typeof(identity), x) = x, ȳ -> (NO_FIELDS, ȳ) +function frule(::typeof(*), x::Number, y::Number) + function times_pushforward(_, Δx, Δy) + return Δx * y + x * Δy + end + return x * y, times_pushforward +end + +function rrule(::typeof(*), x::Number, y::Number) + function times_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) + 
end + return x * y, times_pullback +end + +function frule(::typeof(identity), x) + function identity_pushforward(_, ẏ) + return ẏ + end + return x, identity_pushforward +end + +function rrule(::typeof(identity), x) + function identity_pullback(ȳ) + return (NO_FIELDS, ȳ) + end + return x, identity_pullback +end From f2e5705c72454fa5d004ab9545925c2de4814c2c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 17:23:19 +0100 Subject: [PATCH 28/38] More named propagators --- src/rulesets/Base/broadcast.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 8edb7fa5e..989e857f2 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -15,10 +15,16 @@ end function frule(::typeof(broadcast), f, x) Ω, ∂x = _cast_diff(f, x) - return Ω, (_, Δf, Δx) -> Δx * cast(∂x) + function broadcast_pushforward(_, Δf, Δx) + return Δx * cast(∂x) + end + return Ω, broadcast_pushforward end function rrule(::typeof(broadcast), f, x) values, derivs = _cast_diff(f, x) - return values, ΔΩ -> (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs))) + function broadcast_pullback(ΔΩ) + return (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs))) + end + return values, broadcast_pullback end From a566a0dcd7bfebe86ff260f2b6312b7a364eb825 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 17:51:15 +0100 Subject: [PATCH 29/38] Name more propagators --- src/rulesets/LinearAlgebra/dense.jl | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index f3dd32b04..93edeb154 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -49,13 +49,22 @@ end ##### function frule(::typeof(det), x) - Ω, m = det(x), @thunk(inv(x)) - return Ω, (_, Δx) -> Ω * tr(extern(m * Δx)) + Ω = det(x) + , @thunk(inv(x)) + function det_pushforward(_, ẋ) + # 
PERF-OPT: probably there is an efficent + # way to compute this trace without during the full compution within + + return Ω * tr(inv(x) * ẋ) + end + return Ω, det_pushforward end function rrule(::typeof(det), x) - Ω, m = det(x), @thunk(inv(x)') - return Ω, ΔΩ -> (NO_FIELDS, Ω * ΔΩ * m) + Ω = det(x) + function det_pullback(ΔΩ) + return NO_FIELDS, @thunk(Ω * ΔΩ * inv(x)') + return Ω, det_pullback end ##### From 85d0fa02eb811c2d2fde1e163f8fe8d70d7657b1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 18:07:03 +0100 Subject: [PATCH 30/38] More named propagators --- src/rulesets/LinearAlgebra/structured.jl | 103 +++++++++++++++++++---- 1 file changed, 86 insertions(+), 17 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 8de6cb0ac..c4e9a28d1 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -5,15 +5,30 @@ ##### `Diagonal` ##### -rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), ȳ->(NO_FIELDS, diag(ȳ)) - -rrule(::typeof(diag), A::AbstractMatrix) = diag(A), ȳ->(NO_FIELDS, Diagonal(ȳ)) +function rrule(::Type{<:Diagonal}, d::AbstractVector) + function Diagonal_pullback(ȳ) + return (NO_FIELDS, @thunk(diag(ȳ))) + end + return Diagonal(d), Diagonal_pullback +end + +function rrule(::typeof(diag), A::AbstractMatrix) + function diag_pullback(ȳ) + return (NO_FIELDS, @thunk(Diagonal(ȳ))) + end + return diag(A), diag_pullback +end ##### ##### `Symmetric` ##### -rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), ȳ->(NO_FIELDS, _symmetric_back(ȳ)) +function rrule(::Type{<:Symmetric}, A::AbstractMatrix) + function Symmetric_pullback(ȳ) + return (NO_FIELDS, @thunk(_symmetric_back(ȳ))) + end + return Symmetric(A), Symmetric_pullback +end _symmetric_back(ΔΩ) = @thunk(UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)) _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ @@ -23,26 +38,80 @@ 
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### # ✖️✖️✖️TODO: Deal with complex-valued arrays as well -rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), ȳ->(NO_FIELDS, adjoint(ȳ)) -rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), ȳ->(NO_FIELDS, vec(adjoint(ȳ))) - -rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), ȳ->(NO_FIELDS, adjoint(ȳ)) -rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), ȳ->(NO_FIELDS, vec(adjoint(ȳ))) +function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) + function Adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(adjoint(ȳ))) + end + return Adjoint(A), Adjoint_pullback +end + +function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) + function Adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(vec(adjoint(ȳ)))) + end + return Adjoint(A), Adjoint_pullback +end + +function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) + function adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(adjoint(ȳ))) + end + return adjoint(A), +end + +function rrule(::typeof(adjoint), A::AbstractVector{<:Real}) + function adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(vec(adjoint(ȳ)))) + end + return adjoint(A), +end ##### ##### `Transpose` ##### -rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), ȳ->(NO_FIELDS, transpose(ȳ)) -rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), ȳ->(NO_FIELDS, vec(transpose(ȳ))) - -rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), ȳ->(NO_FIELDS, transpose(ȳ)) -rrule(::typeof(transpose), A::AbstractVector) = transpose(A), ȳ->(NO_FIELDS, vec(transpose(ȳ))) +function rrule(::Type{<:Transpose}, A::AbstractMatrix) + function Transpose_pullback(ȳ) + return (NO_FIELDS, @thunk transpose(ȳ)) + end + return Transpose(A), Transpose_pullback +end + +function rrule(::Type{<:Transpose}, A::AbstractVector) + function Transpose_pullback(ȳ) + return (NO_FIELDS, @thunk vec(transpose(ȳ))) + end + return Transpose(A), Transpose_pullback 
+end + +function rrule(::typeof(transpose), A::AbstractMatrix) + function transpose_pullback(ȳ) + return (NO_FIELDS, @thunk transpose(ȳ)) + end + return transpose(A), transpose_pullback +end + +function rrule(::typeof(transpose), A::AbstractVector) + function transpose_pullback(ȳ) + return (NO_FIELDS, @thunk vec(transpose(ȳ))) + end + return transpose(A), transpose_pullback +end ##### ##### Triangular matrices ##### -rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), ȳ->(NO_FIELDS, Matrix(ȳ)) - -rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), ȳ->(NO_FIELDS, Matrix(ȳ)) +function rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) + function UpperTriangular_pullback(ȳ) + return (NO_FIELDS, @thunk Matrix(ȳ)) + end + return UpperTriangular(A), UpperTriangular_pullback +end + +function rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) + function LowerTriangular_pullback(ȳ) + return (NO_FIELDS, @thunk Matrix(ȳ)) + end + return LowerTriangular(A), LowerTriangular_pullback +end From 0f9b3ac40ce97a9361fdc56d74a482e3d022ab88 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 18:08:30 +0100 Subject: [PATCH 31/38] delete extra unused _update! 
methods --- src/helper_functions.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/helper_functions.jl b/src/helper_functions.jl index 341908128..386e5d1f3 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -17,17 +17,6 @@ function _update!(x::NamedTuple, y, p::Symbol) return merge(x, new) end -#== -function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns - return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns)) -end - - -function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns - return _update!(x, getproperty(y, p), p) -end -==# - """ _checked_rrule From cc9028da57bc6a2550cfb9ecd674e1c4b3e8d6b7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 18:36:25 +0100 Subject: [PATCH 32/38] name more propagators --- src/rulesets/LinearAlgebra/dense.jl | 38 ++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 93edeb154..6a47f4f21 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -50,11 +50,9 @@ end function frule(::typeof(det), x) Ω = det(x) - , @thunk(inv(x)) function det_pushforward(_, ẋ) # PERF-OPT: probably there is an efficent # way to compute this trace without during the full compution within - return Ω * tr(inv(x) * ẋ) end return Ω, det_pushforward @@ -64,6 +62,7 @@ function rrule(::typeof(det), x) Ω = det(x) function det_pullback(ΔΩ) return NO_FIELDS, @thunk(Ω * ΔΩ * inv(x)') + end return Ω, det_pullback end @@ -72,28 +71,49 @@ end ##### function frule(::typeof(logdet), x) - Ω, m = logdet(x), @thunk(inv(x)) - return Ω, (_, Δx) -> tr(extern(m * Δx)) + Ω = logdet(x) + function logdet_pushforward(_, Δx) + return tr(inv(x) * Δx) + end + return Ω, logdet_pushforward end function rrule(::typeof(logdet), x) - Ω, m = logdet(x), @thunk(inv(x)') - return Ω, ΔΩ -> (NO_FIELDS, ΔΩ * m) + Ω = logdet(x) + function
logdet_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * inv(x)')) + end + return Ω, logdet_pullback end ##### ##### `trace` ##### -frule(::typeof(tr), x) = (tr(x), (_, Δx) -> tr(extern(Δx))) -rrule(::typeof(tr), x) = (tr(x), ΔΩ -> (NO_FIELDS, Diagonal(fill(ΔΩ, size(x, 1))))) +function frule(::typeof(tr), x) + function tr_pushforward(_, Δx) + return tr(Δx) + end + return tr(x), tr_pushforward +end + +function rrule(::typeof(tr), x) + function tr_pullback(ΔΩ) + return (NO_FIELDS, @thunk Diagonal(fill(ΔΩ, size(x, 1)))) + end + return tr(x), tr_pullback +end + ##### ##### `*` ##### function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) - return A * B, Ȳ -> (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) + function times_pullback(Ȳ) + return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) + end + return A * B, times_pullback end ##### From add271f41427d3eeb16c5a72eb96f23eca24350f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 18:51:31 +0100 Subject: [PATCH 33/38] fix up typos and extern new thunks in tests --- src/rulesets/LinearAlgebra/structured.jl | 4 ++-- test/rulesets/Base/array.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index c4e9a28d1..04ded78c8 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -56,14 +56,14 @@ function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) function adjoint_pullback(ȳ) return (NO_FIELDS, @thunk(adjoint(ȳ))) end - return adjoint(A), + return adjoint(A), adjoint_pullback end function rrule(::typeof(adjoint), A::AbstractVector{<:Real}) function adjoint_pullback(ȳ) return (NO_FIELDS, @thunk(vec(adjoint(ȳ)))) end - return adjoint(A), + return adjoint(A), adjoint_pullback end ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 4c7a00d5e..c113b6698 100644 --- a/test/rulesets/Base/array.jl +++ 
b/test/rulesets/Base/array.jl @@ -8,7 +8,7 @@ (s̄, Ā, d̄) = pullback(Ȳ) @test s̄ == NO_FIELDS @test d̄ isa DNE - @test Ā == reshape(Ȳ, (5, 4)) + @test extern(Ā) == reshape(Ȳ, (5, 4)) B, pullback = rrule(reshape, A, 5, 4) @test B == reshape(A, 5, 4) @@ -18,7 +18,7 @@ @test s̄ == NO_FIELDS @test d̄1 isa DNE @test d̄2 isa DNE - @test Ā == reshape(Ȳ, 5, 4) + @test extern(Ā) == reshape(Ȳ, 5, 4) end @testset "hcat" begin @@ -57,7 +57,7 @@ end (ds, dv, dd) = pullback(ones(4)) @test ds === NO_FIELDS @test dd isa DNE - @test dv == 4 + @test extern(dv) == 4 y, pullback = rrule(fill, 2.0, (3, 3, 3)) @test y == fill(2.0, (3, 3, 3)) From 3dbdbfc32072c1fc377740445db0fbc75737c88f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 18:56:12 +0100 Subject: [PATCH 34/38] test nonsquares --- test/rulesets/Base/mapreduce.jl | 2 +- test/rulesets/Statistics/statistics.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 5efc422d2..a8699663f 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -57,7 +57,7 @@ @testset "keyword arguments" begin rng = MersenneTwister(33) n = 4 - X = randn(rng, n, n) + X = randn(rng, n, n+1) y, pullback = rrule(sum, X; dims=2) ȳ = randn(rng, size(y)) _, x̄_ad = pullback(ȳ) diff --git a/test/rulesets/Statistics/statistics.jl b/test/rulesets/Statistics/statistics.jl index ee2af7671..c31af6f15 100644 --- a/test/rulesets/Statistics/statistics.jl +++ b/test/rulesets/Statistics/statistics.jl @@ -22,7 +22,7 @@ end @testset "with dims kwargs" begin - X = randn(rng, n, n) + X = randn(rng, n, n+1) y, mean_pullback = rrule(mean, X; dims=1) ȳ = randn(rng, size(y)) _, dX = mean_pullback(ȳ) From 6ca8deba6e196beca16d56095e650836160adf47 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 18 Sep 2019 19:01:30 +0100 Subject: [PATCH 35/38] more named propagators --- src/rulesets/LinearAlgebra/blas.jl | 10 ++++++++-- 1 file changed, 8 
insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index bd8e42cb5..34702b1fd 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -71,11 +71,17 @@ end ##### function frule(::typeof(BLAS.asum), x) - return BLAS.asum(x), (_, Δx) -> sum(cast(sign, x) * Δx) + function asum_pushforward(_, Δx) + return sum(cast(sign, x) * Δx) + end + return BLAS.asum(x), asum_pushforward end function rrule(::typeof(BLAS.asum), x) - return BLAS.asum(x), ΔΩ -> (NO_FIELDS, @thunk(ΔΩ * cast(sign, x))) + function asum_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x))) + end + return BLAS.asum(x), asum_pullback end function rrule(::typeof(BLAS.asum), n, X, incx) From 320dba0a4b773957758ad50fb8bf26e8c54eeb9b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 19 Sep 2019 18:08:00 +0100 Subject: [PATCH 36/38] Apply suggestions from code review Co-Authored-By: Nick Robinson --- src/rulesets/Base/mapreduce.jl | 4 ++-- src/rulesets/LinearAlgebra/blas.jl | 5 +++-- src/rulesets/LinearAlgebra/dense.jl | 14 +++++++------- src/rulesets/LinearAlgebra/structured.jl | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 4d7dfa73e..a21b4ffcb 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -24,7 +24,7 @@ end ##### for mf in (:mapreduce, :mapfoldl, :mapfoldr) - sig = :(ChainRulesCore.rrule(::typeof($mf), f, op, x::AbstractArray{<:Real})) + sig = :(rrule(::typeof($mf), f, op, x::AbstractArray{<:Real})) call = :($mf(f, op, x)) if mf === :mapreduce insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:)))) @@ -73,7 +73,7 @@ end function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) y = sum(abs2, x; dims=dims) function sum_abs2_pullback(ȳ) - (NO_FIELDS, DNE(), @thunk(2ȳ .* x)) + return (NO_FIELDS, DNE(), @thunk(2ȳ .* x)) end return y, 
sum_abs2_pullback end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 34702b1fd..5772d2832 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -46,7 +46,7 @@ end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) function nrm2_pullback(ΔΩ) - return NO_FIELDS, @thunk(ΔΩ * @thunk(x * inv(Ω))) + return NO_FIELDS, @thunk(ΔΩ * x * inv(Ω)) end return Ω, nrm2_pullback end @@ -92,7 +92,8 @@ function rrule(::typeof(BLAS.asum), n, X, incx) else ΔΩ = extern(ΔΩ) ∂X = @thunk scal!( - n, ΔΩ, + n, + ΔΩ, blascopy!(n, sign.(X), incx, _zeros(X), incx), incx ) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 6a47f4f21..861cce8d8 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -10,14 +10,14 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} function frule(::typeof(dot), x, y) function dot_pushforward(Δself, Δx, Δy) - sum(Δx * cast(y)) + sum(cast(x) * Δy) + return sum(Δx * cast(y)) + sum(cast(x) * Δy) end return dot(x, y), dot_pushforward end function rrule(::typeof(dot), x, y) function dot_pullback(ΔΩ) - (NO_FIELDS, ΔΩ * cast(y), cast(x) * ΔΩ,) + return (NO_FIELDS, ΔΩ * cast(y), cast(x) * ΔΩ,) end return dot(x, y), dot_pullback end @@ -30,7 +30,7 @@ function frule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω) function inv_pushforward(_, Δx) - m * Δx * Ω + return m * Δx * Ω end return Ω, inv_pushforward end @@ -39,7 +39,7 @@ function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω') function inv_pullback(ΔΩ) - NO_FIELDS, m * ΔΩ * Ω' + return NO_FIELDS, m * ΔΩ * Ω' end return Ω, inv_pullback end @@ -51,7 +51,7 @@ end function frule(::typeof(det), x) Ω = det(x) function det_pushforward(_, ẋ) - # PERF-OPT: probably there is an efficent + # TODO Performance optimization: probably there is an efficent # way to compute this trace without during the full compution within 
return Ω * tr(inv(x) * ẋ) end @@ -126,7 +126,7 @@ function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatr S = T.name.wrapper ∂A = @thunk Ȳ / B' ∂B = @thunk S(-Y' * (Ȳ / B')) - (NO_FIELDS, ∂A, ∂B) + return (NO_FIELDS, ∂A, ∂B) end return Y, slash_pullback end @@ -202,7 +202,7 @@ end function rrule(::typeof(norm), x::Real, p::Real=2) function norm_pullback(ȳ) ∂x = @thunk ȳ * sign(x) - ∂p = @thunk zero(x) #TODO: should this be Zero()? + ∂p = @thunk zero(x) # TODO: should this be Zero()? (NO_FIELDS, ∂x, ∂p) end return norm(x, p), norm_pullback diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 04ded78c8..3e8934afc 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -30,7 +30,7 @@ function rrule(::Type{<:Symmetric}, A::AbstractMatrix) return Symmetric(A), Symmetric_pullback end -_symmetric_back(ΔΩ) = @thunk(UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)) +_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### From c3c20f6b2300ce176b8681d8d07add2eff36c03e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 19 Sep 2019 18:33:40 +0100 Subject: [PATCH 37/38] more named propagators --- src/rulesets/Base/mapreduce.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index a21b4ffcb..bc1ea223a 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -50,9 +50,19 @@ end ##### `sum` ##### -frule(::typeof(sum), x) = (sum(x), (_, ẋ)->sum(ẋ)) +function frule(::typeof(sum), x) + function sum_pushforward(_, ẋ) + return sum(ẋ) + end + return sum(x), sum_pushforward +end -rrule(::typeof(sum), x) = (sum(x), ȳ->(NO_FIELDS, cast(ȳ))) +function rrule(::typeof(sum), x) + function sum_pullback(ȳ) + return (NO_FIELDS, cast(ȳ)) + end + return 
sum(x), sum_pullback +end function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims) From bcb24db948151e438073f3f346180e764eaa4980 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 19 Sep 2019 22:24:43 +0100 Subject: [PATCH 38/38] Update test/rulesets/Base/base.jl Co-Authored-By: Nick Robinson --- test/rulesets/Base/base.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 3e5e776f9..c8fd73827 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -55,7 +55,7 @@ x, y = rand(2) @testset "atan2" begin # https://en.wikipedia.org/wiki/Atan2 - ratan = atan(x, y) # https://en.wikipedia.org/wiki/Atan2 + ratan = atan(x, y) u = x^2 + y^2 datan = y/u - 2x/u