diff --git a/Project.toml b/Project.toml index acc564c2d..e43431f5c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.1.1" +version = "0.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "^0.2" +ChainRulesCore = "^0.3" FiniteDifferences = "^0.7" julia = "^1.0" diff --git a/src/helper_functions.jl b/src/helper_functions.jl index 392ae2de6..386e5d1f3 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -1,8 +1,4 @@ -# Special purpose updating for operations which can be done in-place. This function is -# just internal and free-form; it is not a method of `accumulate!` directly as it does -# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`. -# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular -# rule. +# Internal helpers for defining the `add!` field of an `InplaceableThunk` _update!(x, y) = x + y _update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y @@ -11,20 +7,22 @@ _update!(x, ::Zero) = x _update!(::Zero, y) = y _update!(::Zero, ::Zero) = Zero() -function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns - return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns)) -end function _update!(x::NamedTuple, y, p::Symbol) - new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),)) + y = extern(y) + yp = getproperty(y, p) + xp = getproperty(x, p) + new_xp = _update!(xp, yp) + new = NamedTuple{(p,)}((new_xp,)) return merge(x, new) end -function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns - return _update!(x, getproperty(y, p), p) -end - +""" + _checked_rrule +like `rrule` but throws an error if the `rrule` is not defined. +Rather than returning `nothing` +""" function _checked_rrule(f, args...; kwargs...) r = rrule(f, args...; kwargs...) r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index b071f0447..fa69b17c7 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -3,12 +3,18 @@ ##### function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) - return reshape(A, dims), (Rule(Ȳ->reshape(Ȳ, dims)), DNERule()) + function reshape_pullback(Ȳ) + return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DNE()) + end + return reshape(A, dims), reshape_pullback end function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) - Y, (rule, _) = rrule(reshape, A, dims) - return Y, (rule, fill(DNERule(), length(dims))...) + function reshape_pullback(Ȳ) + ∂A = @thunk(reshape(Ȳ, dims)) + return (NO_FIELDS, ∂A, fill(DNE(), length(dims))...) + end + return reshape(A, dims...), reshape_pullback end ##### @@ -16,17 +22,21 @@ end ##### function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) - Y = hcat(A, Bs...) - Xs = (A, Bs...) - rules = ntuple(length(Bs) + 1) do i - l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) - u = l + size(Xs[i], 2) - dim = u > l + 1 ? (l+1:u) : u - # NOTE: The copy here is defensive, since `selectdim` returns a view which we can - # materialize with `copy` - Rule(Ȳ->copy(selectdim(Ȳ, 2, dim))) + function hcat_pullback(Ȳ) + Xs = (A, Bs...) + ntuple(length(Bs) + 2) do full_i + full_i == 1 && return NO_FIELDS + + i = full_i - 1 + l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) + u = l + size(Xs[i], 2) + dim = u > l + 1 ? (l+1:u) : u + # NOTE: The copy here is defensive, since `selectdim` returns a view which we can + # materialize with `copy` + copy(selectdim(Ȳ, 2, dim)) + end end - return Y, rules + return hcat(A, Bs...), hcat_pullback end ##### @@ -34,15 +44,17 @@ end ##### function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) - Y = vcat(A, Bs...) - n = size(A, 1) - ∂A = Rule(Ȳ->copy(selectdim(Ȳ, 1, 1:n))) - ∂Bs = ntuple(length(Bs)) do i - l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0) - u = l + size(Bs[i], 1) - Rule(Ȳ->copy(selectdim(Ȳ, 1, l+1:u))) + function vcat_pullback(Ȳ) + n = size(A, 1) + ∂A = copy(selectdim(Ȳ, 1, 1:n)) + ∂Bs = ntuple(length(Bs)) do i + l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0) + u = l + size(Bs[i], 1) + copy(selectdim(Ȳ, 1, l+1:u)) + end + return (NO_FIELDS, ∂A, ∂Bs...) end - return Y, (∂A, ∂Bs...) + return vcat(A, Bs...), vcat_pullback end ##### @@ -50,9 +62,15 @@ end ##### function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) - return fill(value, dims), (Rule(sum), DNERule()) + function fill_pullback(Ȳ) + return (NO_FIELDS, @thunk(sum(Ȳ)), DNE()) + end + return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) - return fill(value, dims), (Rule(sum), ntuple(_->DNERule(), length(dims))...) + function fill_pullback(Ȳ) + return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DNE(), length(dims))...) + end + return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index b6f67f5d6..6ce4635a5 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -103,10 +103,30 @@ # product rule requires special care for arguments where `mul` is non-commutative -frule(::typeof(*), x::Number, y::Number) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy) - -rrule(::typeof(*), x::Number, y::Number) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) - -frule(::typeof(identity), x) = x, Rule(identity) - -rrule(::typeof(identity), x) = x, Rule(identity) +function frule(::typeof(*), x::Number, y::Number) + function times_pushforward(_, Δx, Δy) + return Δx * y + x * Δy + end + return x * y, times_pushforward +end + +function rrule(::typeof(*), x::Number, y::Number) + function times_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) + end + return x * y, times_pullback +end + +function frule(::typeof(identity), x) + function identity_pushforward(_, ẏ) + return ẏ + end + return x, identity_pushforward +end + +function rrule(::typeof(identity), x) + function identity_pullback(ȳ) + return (NO_FIELDS, ȳ) + end + return x, identity_pullback +end diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index f3685f5ca..989e857f2 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -5,9 +5,9 @@ without relying on inference hacks unless we have something akin to https://github.com/JuliaLang/julia/issues/22129. =# function _cast_diff(f, x) - element_rule = u -> begin + function element_rule(u) fu, du = frule(f, u) - fu, extern(du(One())) + fu, extern(du(NamedTuple(), One())) end results = broadcast(element_rule, x) return first.(results), last.(results) @@ -15,10 +15,16 @@ end function frule(::typeof(broadcast), f, x) Ω, ∂x = _cast_diff(f, x) - return Ω, Rule((_, Δx) -> Δx * cast(∂x)) + function broadcast_pushforward(_, Δf, Δx) + return Δx * cast(∂x) + end + return Ω, broadcast_pushforward end function rrule(::typeof(broadcast), f, x) values, derivs = _cast_diff(f, x) - return values, (DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs))) + function broadcast_pullback(ΔΩ) + return (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs))) + end + return values, broadcast_pullback end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index b0e6a006b..bc1ea223a 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -4,15 +4,19 @@ function rrule(::typeof(map), f, xs...) y = map(f, xs...) - ∂xs = ntuple(length(xs)) do i - Rule() do ȳ - map(ȳ, xs...) do ȳi, xis... - _, ∂xis = _checked_rrule(f, xis...) - extern(∂xis[i](ȳi)) + function map_pullback(ȳ) + ntuple(length(xs)+2) do full_i + full_i == 1 && return NO_FIELDS + full_i == 2 && return DNE() + i = full_i-2 + @thunk map(ȳ, xs...) do ȳi, xis... + _, pullback = _checked_rrule(f, xis...) + ∂xis = pullback(ȳi) + extern(∂xis[i+1]) #+1 to skp ∂self end end end - return y, (DNERule(), ∂xs...) + return y, map_pullback end ##### @@ -26,15 +30,18 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr) insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:)))) insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims))) end + pullback_name = Symbol(mf, :_pullback) body = quote y = $call - ∂x = Rule() do ȳ - broadcast(x, ȳ) do xi, ȳi - _, ∂xi = _checked_rrule(f, xi) - extern(∂xi(ȳi)) + function $pullback_name(ȳ) + ∂x = @thunk broadcast(x, ȳ) do xi, ȳi + _, pullback_f = _checked_rrule(f, xi) + _, ∂xi = pullback_f(ȳi) + extern(∂xi) end + (NO_FIELDS, DNE(), DNE(), ∂x) end - return y, (DNERule(), DNERule(), ∂x) + return y, $pullback_name end eval(Expr(:function, sig, body)) end @@ -43,22 +50,40 @@ end ##### `sum` ##### -frule(::typeof(sum), x) = (sum(x), Rule(sum)) +function frule(::typeof(sum), x) + function sum_pushforward(_, ẋ) + return sum(ẋ) + end + return sum(x), sum_pushforward +end -rrule(::typeof(sum), x) = (sum(x), Rule(cast)) +function rrule(::typeof(sum), x) + function sum_pullback(ȳ) + return (NO_FIELDS, cast(ȳ)) + end + return sum(x), sum_pullback +end function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) - y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims) - return y, (DNERule(), ∂x) + y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims) + function sum_pullback(ȳ) + NO_FIELDS, DNE(), last(mr_pullback(ȳ)) + end + return y, sum_pullback end function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:) - y, (_, ∂x) = rrule(sum, identity, x; dims=dims) - return y, ∂x + y, inner_pullback = rrule(sum, identity, x; dims=dims) + function sum_pullback(ȳ) + NO_FIELDS, last(inner_pullback(ȳ)) + end + return y, sum_pullback end function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) y = sum(abs2, x; dims=dims) - ∂x = Rule(ȳ -> 2ȳ .* x) - return y, (DNERule(), ∂x) + function sum_abs2_pullback(ȳ) + return (NO_FIELDS, DNE(), @thunk(2ȳ .* x)) + end + return y, sum_abs2_pullback end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index fb5b23f4b..5772d2832 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -7,8 +7,6 @@ using LinearAlgebra: BlasFloat _zeros(x) = fill!(similar(x), zero(eltype(x))) -_rule_via(∂) = Rule(ΔΩ -> isa(ΔΩ, Zero) ? ΔΩ : ∂(extern(ΔΩ))) - ##### ##### `BLAS.dot` ##### @@ -19,9 +17,18 @@ rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y) function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) Ω = BLAS.dot(n, X, incx, Y, incy) - ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) - ∂Y = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) - return Ω, (DNERule(), _rule_via(∂X), DNERule(), _rule_via(∂Y), DNERule()) + function blas_dot_pullback(ΔΩ) + if ΔΩ isa Zero + ∂X = Zero() + ∂Y = Zero() + else + ΔΩ = extern(ΔΩ) + ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) + ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) + end + return (NO_FIELDS, DNE(), ∂X, DNE(), ∂Y, DNE()) + end + return Ω, blas_dot_pullback end ##### @@ -30,32 +37,70 @@ end function frule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) - return Ω, Rule(Δx -> sum(Δx * cast(@thunk(x * inv(Ω))))) + function nrm2_pushforward(_, Δx) + return sum(Δx * cast(@thunk(x * inv(Ω)))) + end + return Ω, nrm2_pushforward end function rrule(::typeof(BLAS.nrm2), x) Ω = BLAS.nrm2(x) - return Ω, Rule(ΔΩ -> ΔΩ * @thunk(x * inv(Ω))) + function nrm2_pullback(ΔΩ) + return NO_FIELDS, @thunk(ΔΩ * x * inv(Ω)) + end + return Ω, nrm2_pullback end function rrule(::typeof(BLAS.nrm2), n, X, incx) Ω = BLAS.nrm2(n, X, incx) - ∂X = ΔΩ -> scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) - return Ω, (DNERule(), _rule_via(∂X), DNERule()) + function nrm2_pullback(ΔΩ) + if ΔΩ isa Zero + ∂X = Zero() + else + ΔΩ = extern(ΔΩ) + ∂X = scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) + end + return (NO_FIELDS, DNE(), ∂X, DNE()) + end + + return Ω, nrm2_pullback end ##### ##### `BLAS.asum` ##### -frule(::typeof(BLAS.asum), x) = (BLAS.asum(x), Rule(Δx -> sum(cast(sign, x) * Δx))) +function frule(::typeof(BLAS.asum), x) + function asum_pushforward(_, Δx) + return sum(cast(sign, x) * Δx) + end + return BLAS.asum(x), asum_pushforward +end -rrule(::typeof(BLAS.asum), x) = (BLAS.asum(x), Rule(ΔΩ -> ΔΩ * cast(sign, x))) +function rrule(::typeof(BLAS.asum), x) + function asum_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x))) + end + return BLAS.asum(x), asum_pullback +end function rrule(::typeof(BLAS.asum), n, X, incx) Ω = BLAS.asum(n, X, incx) - ∂X = ΔΩ -> scal!(n, ΔΩ, blascopy!(n, sign.(X), incx, _zeros(X), incx), incx) - return Ω, (DNERule(), _rule_via(∂X), DNERule()) + function asum_pullback(ΔΩ) + if ΔΩ isa Zero + ∂X = Zero() + else + ΔΩ = extern(ΔΩ) + ∂X = @thunk scal!( + n, + ΔΩ, + blascopy!(n, sign.(X), incx, _zeros(X), incx), + incx + ) + end + return (NO_FIELDS, DNE(), ∂X, DNE()) + end + return Ω, asum_pullback end ##### @@ -65,20 +110,39 @@ end function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat y = gemv(tA, α, A, x) - if uppercase(tA) === 'N' - ∂A = Rule(ȳ -> α * ȳ * x', (Ā, ȳ) -> ger!(α, ȳ, x, Ā)) - ∂x = Rule(ȳ -> gemv('T', α, A, ȳ), (x̄, ȳ) -> gemv!('T', α, A, ȳ, one(T), x̄)) - else - ∂A = Rule(ȳ -> α * x * ȳ', (Ā, ȳ) -> ger!(α, x, ȳ, Ā)) - ∂x = Rule(ȳ -> gemv('N', α, A, ȳ), (x̄, ȳ) -> gemv!('N', α, A, ȳ, one(T), x̄)) + function gemv_pullback(ȳ) + if uppercase(tA) === 'N' + ∂A = InplaceableThunk( + @thunk(α * ȳ * x'), + Ā -> ger!(α, ȳ, x, Ā) + ) + ∂x = InplaceableThunk( + @thunk(gemv('T', α, A, ȳ)), + x̄ -> gemv!('T', α, A, ȳ, one(T), x̄) + ) + else + ∂A = InplaceableThunk( + @thunk(α * x * ȳ'), + Ā -> ger!(α, x, ȳ, Ā) + ) + ∂x = InplaceableThunk( + @thunk(gemv('N', α, A, ȳ)), + x̄ -> gemv!('N', α, A, ȳ, one(T), x̄) + ) + end + return (NO_FIELDS, DNE(), @thunk(dot(ȳ, y) / α), ∂A, ∂x) end - return y, (DNERule(), Rule(ȳ -> dot(ȳ, y) / α), ∂A, ∂x) + return y, gemv_pullback end function rrule(::typeof(gemv), tA::Char, A::AbstractMatrix{T}, x::AbstractVector{T}) where T<:BlasFloat - y, (dtA, _, dA, dx) = rrule(gemv, tA, one(T), A, x) - return y, (dtA, dA, dx) + y, inner_pullback = rrule(gemv, tA, one(T), A, x) + function gemv_pullback(Ȳ) + (_, dtA, _, dA, dx) = inner_pullback(Ȳ) + return (NO_FIELDS, dtA, dA, dx) + end + return y, gemv_pullback end ##### @@ -88,37 +152,60 @@ end function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat C = gemm(tA, tB, α, A, B) - β = one(T) - if uppercase(tA) === 'N' - if uppercase(tB) === 'N' - ∂A = Rule(C̄ -> gemm('N', 'T', α, C̄, B), - (Ā, C̄) -> gemm!('N', 'T', α, C̄, B, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'N', α, A, C̄), - (B̄, C̄) -> gemm!('T', 'N', α, A, C̄, β, B̄)) + function gemv_pullback(C̄) + β = one(T) + if uppercase(tA) === 'N' + if uppercase(tB) === 'N' + ∂A = InplaceableThunk( + @thunk(gemm('N', 'T', α, C̄, B)), + Ā -> gemm!('N', 'T', α, C̄, B, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('T', 'N', α, A, C̄)), + B̄ -> gemm!('T', 'N', α, A, C̄, β, B̄) + ) + else + ∂A = InplaceableThunk( + @thunk(gemm('N', 'N', α, C̄, B)), + Ā -> gemm!('N', 'N', α, C̄, B, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('T', 'N', α, C̄, A)), + B̄ -> gemm!('T', 'N', α, C̄, A, β, B̄) + ) + end else - ∂A = Rule(C̄ -> gemm('N', 'N', α, C̄, B), - (Ā, C̄) -> gemm!('N', 'N', α, C̄, B, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'N', α, C̄, A), - (B̄, C̄) -> gemm!('T', 'N', α, C̄, A, β, B̄)) - end - else - if uppercase(tB) === 'N' - ∂A = Rule(C̄ -> gemm('N', 'T', α, B, C̄), - (Ā, C̄) -> gemm!('N', 'T', α, B, C̄, β, Ā)) - ∂B = Rule(C̄ -> gemm('N', 'N', α, A, C̄), - (B̄, C̄) -> gemm!('N', 'N', α, A, C̄, β, B̄)) - else - ∂A = Rule(C̄ -> gemm('T', 'T', α, B, C̄), - (Ā, C̄) -> gemm!('T', 'T', α, B, C̄, β, Ā)) - ∂B = Rule(C̄ -> gemm('T', 'T', α, C̄, A), - (B̄, C̄) -> gemm!('T', 'T', α, C̄, A, β, B̄)) + if uppercase(tB) === 'N' + ∂A = InplaceableThunk( + @thunk(gemm('N', 'T', α, B, C̄)), + Ā -> gemm!('N', 'T', α, B, C̄, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('N', 'N', α, A, C̄)), + B̄ -> gemm!('N', 'N', α, A, C̄, β, B̄) + ) + else + ∂A = InplaceableThunk( + @thunk(gemm('T', 'T', α, B, C̄)), + Ā -> gemm!('T', 'T', α, B, C̄, β, Ā) + ) + ∂B = InplaceableThunk( + @thunk(gemm('T', 'T', α, C̄, A)), + B̄ -> gemm!('T', 'T', α, C̄, A, β, B̄) + ) + end end + return (NO_FIELDS, DNE(), DNE(), @thunk(dot(C̄, C) / α), ∂A, ∂B) end - return C, (DNERule(), DNERule(), Rule(C̄ -> dot(C̄, C) / α), ∂A, ∂B) + return C, gemv_pullback end function rrule(::typeof(gemm), tA::Char, tB::Char, A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat - C, (dtA, dtB, _, dA, dB) = rrule(gemm, tA, tB, one(T), A, B) - return C, (dtA, dtB, dA, dB) + C, inner_pullback = rrule(gemm, tA, tB, one(T), A, B) + function gemv_pullback(Ȳ) + (_, dtA, dtB, _, dA, dB) = inner_pullback(Ȳ) + return (NO_FIELDS, dtA, dtB, dA, dB) + end + return C, gemm_pullback end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 9eb3ee168..861cce8d8 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -9,11 +9,17 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}} ##### function frule(::typeof(dot), x, y) - return dot(x, y), Rule((Δx, Δy) -> sum(Δx * cast(y)) + sum(cast(x) * Δy)) + function dot_pushforward(Δself, Δx, Δy) + return sum(Δx * cast(y)) + sum(cast(x) * Δy) + end + return dot(x, y), dot_pushforward end function rrule(::typeof(dot), x, y) - return dot(x, y), (Rule(ΔΩ -> ΔΩ * cast(y)), Rule(ΔΩ -> cast(x) * ΔΩ)) + function dot_pullback(ΔΩ) + return (NO_FIELDS, ΔΩ * cast(y), cast(x) * ΔΩ,) + end + return dot(x, y), dot_pullback end ##### @@ -23,13 +29,19 @@ end function frule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω) - return Ω, Rule(Δx -> m * Δx * Ω) + function inv_pushforward(_, Δx) + return m * Δx * Ω + end + return Ω, inv_pushforward end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) m = @thunk(-Ω') - return Ω, Rule(ΔΩ -> m * ΔΩ * Ω') + function inv_pullback(ΔΩ) + return NO_FIELDS, m * ΔΩ * Ω' + end + return Ω, inv_pullback end ##### @@ -37,13 +49,21 @@ end ##### function frule(::typeof(det), x) - Ω, m = det(x), @thunk(inv(x)) - return Ω, Rule(Δx -> Ω * tr(extern(m * Δx))) + Ω = det(x) + function det_pushforward(_, ẋ) + # TODO Performance optimization: probably there is an efficent + # way to compute this trace without during the full compution within + return Ω * tr(inv(x) * ẋ) + end + return Ω, det_pushforward end function rrule(::typeof(det), x) - Ω, m = det(x), @thunk(inv(x)') - return Ω, Rule(ΔΩ -> Ω * ΔΩ * m) + Ω = det(x) + function det_pullback(ΔΩ) + return NO_FIELDS, @thunk(Ω * ΔΩ * inv(x)') + end + return Ω, det_pullback end ##### @@ -51,29 +71,49 @@ end ##### function frule(::typeof(logdet), x) - Ω, m = logdet(x), @thunk(inv(x)) - return Ω, Rule(Δx -> tr(extern(m * Δx))) + Ω = logdet(x) + function logdet_pushforward(_, Δx) + return tr(inv(x) * Δx) + end + return Ω, logdet_pushforward end function rrule(::typeof(logdet), x) - Ω, m = logdet(x), @thunk(inv(x)') - return Ω, Rule(ΔΩ -> ΔΩ * m) + Ω = logdet(x) + function logdet_pullback(ΔΩ) + return (NO_FIELDS, @thunk(ΔΩ * inv(x)')) + end + return Ω, logdet_pullback end ##### ##### `trace` ##### -frule(::typeof(tr), x) = (tr(x), Rule(Δx -> tr(extern(Δx)))) +function frule(::typeof(tr), x) + function tr_pushforward(_, Δx) + return tr(Δx) + end + return tr(x), tr_pushforward +end + +function rrule(::typeof(tr), x) + function tr_pullback(ΔΩ) + return (NO_FIELDS, @thunk Diagonal(fill(ΔΩ, size(x, 1)))) + end + return tr(x), tr_pullback +end -rrule(::typeof(tr), x) = (tr(x), Rule(ΔΩ -> Diagonal(fill(ΔΩ, size(x, 1))))) ##### ##### `*` ##### function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}) - return A * B, (Rule(Ȳ -> Ȳ * B'), Rule(Ȳ -> A' * Ȳ)) + function times_pullback(Ȳ) + return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ)) + end + return A * B, times_pullback end ##### @@ -82,20 +122,34 @@ end function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real} Y = A / B - S = T.name.wrapper - ∂A = Rule(Ȳ -> Ȳ / B') - ∂B = Rule(Ȳ -> S(-Y' * (Ȳ / B'))) - return Y, (∂A, ∂B) + function slash_pullback(Ȳ) + S = T.name.wrapper + ∂A = @thunk Ȳ / B' + ∂B = @thunk S(-Y' * (Ȳ / B')) + return (NO_FIELDS, ∂A, ∂B) + end + return Y, slash_pullback end function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) - Aᵀ, dA = rrule(adjoint, A) - Bᵀ, dB = rrule(adjoint, B) - Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ) - C, dC = rrule(adjoint, Cᵀ) - ∂A = Rule(dA∘dAᵀ∘dC) - ∂B = Rule(dA∘dBᵀ∘dC) - return C, (∂A, ∂B) + Aᵀ, dA_pb = rrule(adjoint, A) + Bᵀ, dB_pb = rrule(adjoint, B) + Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) + C, dC_pb = rrule(adjoint, Cᵀ) + function slash_pullback(Ȳ) + # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want + # this is not a problem if you want the 2nd or 3rd, but if you want the first, it + # is fairly wasteful + _, dC = dC_pb(Ȳ) + _, dBᵀ, dAᵀ = dS_pb(extern(dC)) + + # need to extern as dAᵀ, dBᵀ are generally `Thunk`s, which don't support adjoint + ∂A = @thunk last(dA_pb(extern(dAᵀ))) + ∂B = @thunk last(dA_pb(extern(dBᵀ))) + + (NO_FIELDS, ∂A, ∂B) + end + return C, slash_pullback end ##### @@ -104,23 +158,30 @@ end function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} Y = A \ B - S = T.name.wrapper - ∂A = Rule(Ȳ -> S(-(A' \ Ȳ) * Y')) - ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (∂A, ∂B) + function backslash_pullback(Ȳ) + S = T.name.wrapper + ∂A = @thunk S(-(A' \ Ȳ) * Y') + ∂B = @thunk A' \ Ȳ + return NO_FIELDS, ∂A, ∂B + end + return Y, backslash_pullback end function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) Y = A \ B - ∂A = Rule() do Ȳ - B̄ = A' \ Ȳ - Ā = -B̄ * Y' - _add!(Ā, (B - A * Y) * B̄' / A') - _add!(Ā, A' \ Y * (Ȳ' - B̄'A)) - Ā + function backslash_pullback(Ȳ) + ∂A = @thunk begin + B̄ = A' \ Ȳ + Ā = -B̄ * Y' + _add!(Ā, (B - A * Y) * B̄' / A') + _add!(Ā, A' \ Y * (Ȳ' - B̄'A)) + Ā + end + ∂B = @thunk A' \ Ȳ + return NO_FIELDS, ∂A, ∂B end - ∂B = Rule(Ȳ -> A' \ Ȳ) - return Y, (∂A, ∂B) + return Y, backslash_pullback + end ##### @@ -129,12 +190,20 @@ end function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2) y = norm(A, p) - u = y^(1-p) - ∂A = Rule(ȳ -> ȳ .* u .* abs.(A).^p ./ A) - ∂p = Rule(ȳ -> ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p) - return y, (∂A, ∂p) + function norm_pullback(ȳ) + u = y^(1-p) + ∂A = @thunk ȳ .* u .* abs.(A).^p ./ A + ∂p = @thunk ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p + (NO_FIELDS, ∂A, ∂p) + end + return y, norm_pullback end function rrule(::typeof(norm), x::Real, p::Real=2) - return norm(x, p), (Rule(ȳ -> ȳ * sign(x)), Rule(_ -> zero(x))) + function norm_pullback(ȳ) + ∂x = @thunk ȳ * sign(x) + ∂p = @thunk zero(x) # TODO: should this be Zero()? + (NO_FIELDS, ∂x, ∂p) + end + return norm(x, p), norm_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 72527fcc6..6fc69ebee 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -7,25 +7,31 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) F = svd(X) - ∂X = Rule() do Ȳ::NamedTuple{(:U,:S,:V)} - svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) + function svd_pullback(Ȳ::NamedTuple{(:U,:S,:V)}) + ∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)) + return (NO_FIELDS, ∂X) end - return F, ∂X + return F, svd_pullback end function rrule(::typeof(getproperty), F::SVD, x::Symbol) - if x === :U - rule = Ȳ->(U=Ȳ, S=zero(F.S), V=zero(F.V)) - elseif x === :S - rule = Ȳ->(U=zero(F.U), S=Ȳ, V=zero(F.V)) - elseif x === :V - rule = Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ) - elseif x === :Vt - # TODO: This could be made to work, but it'd be a pain - throw(ArgumentError("Vt is unsupported; use V and transpose the result")) + function getproperty_svd_pullback(Ȳ) + if x === :U + ∂ = @thunk((; U=Ȳ, S=(zero(F.S)), V=(zero(F.V)))) + elseif x === :S + ∂ = @thunk((; U=(zero(F.U)), S=Ȳ, V=(zero(F.V)))) + elseif x === :V + ∂ = @thunk((; U=(zero(F.U)), S=(zero(F.S)), V=Ȳ)) + elseif x === :Vt + # TODO: This could be made to work, but it'd be a pain + throw(ArgumentError("Vt is unsupported; use V and transpose the result")) + end + + update = (X̄::NamedTuple{(:U,:S,:V)}) -> _update!(X̄, ∂, x) + ∂F = InplaceableThunk(∂, update) + return NO_FIELDS, ∂F, DNE() end - update = (X̄::NamedTuple{(:U,:S,:V)}, Ȳ)->_update!(X̄, rule(Ȳ), x) - return getproperty(F, x), (Rule(rule, update), DNERule()) + return getproperty(F, x), getproperty_svd_pullback end function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix) @@ -65,25 +71,31 @@ end function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) F = cholesky(X) - ∂X = Rule(Ȳ->chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) - return F, ∂X + function cholesky_pullback(Ȳ) + ∂X = @thunk(chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) + return (NO_FIELDS, ∂X) + end + return F, cholesky_pullback end function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) - if x === :U - if F.uplo === 'U' - ∂F = Ȳ->UpperTriangular(Ȳ) - else - ∂F = Ȳ->LowerTriangular(Ȳ') - end - elseif x === :L - if F.uplo === 'L' - ∂F = Ȳ->LowerTriangular(Ȳ) - else - ∂F = Ȳ->UpperTriangular(Ȳ') + function getproperty_cholesky_pullback(Ȳ) + if x === :U + if F.uplo === 'U' + ∂F = @thunk UpperTriangular(Ȳ) + else + ∂F = @thunk LowerTriangular(Ȳ') + end + elseif x === :L + if F.uplo === 'L' + ∂F = @thunk LowerTriangular(Ȳ) + else + ∂F = @thunk UpperTriangular(Ȳ') + end end + return NO_FIELDS, ∂F, DNE() end - return getproperty(F, x), (Rule(∂F), DNERule()) + return getproperty(F, x), getproperty_cholesky_pullback end # See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, @@ -184,7 +196,7 @@ end """ chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) -Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly +Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index d2ee20309..3e8934afc 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -1,18 +1,34 @@ # Structured matrices + ##### ##### `Diagonal` ##### -rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag) +function rrule(::Type{<:Diagonal}, d::AbstractVector) + function Diagonal_pullback(ȳ) + return (NO_FIELDS, @thunk(diag(ȳ))) + end + return Diagonal(d), Diagonal_pullback +end -rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal) +function rrule(::typeof(diag), A::AbstractMatrix) + function diag_pullback(ȳ) + return (NO_FIELDS, @thunk(Diagonal(ȳ))) + end + return diag(A), diag_pullback +end ##### ##### `Symmetric` ##### -rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) +function rrule(::Type{<:Symmetric}, A::AbstractMatrix) + function Symmetric_pullback(ȳ) + return (NO_FIELDS, @thunk(_symmetric_back(ȳ))) + end + return Symmetric(A), Symmetric_pullback +end _symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ @@ -21,27 +37,81 @@ _symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ ##### `Adjoint` ##### -# TODO: Deal with complex-valued arrays as well -rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), Rule(adjoint) -rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), Rule(vec∘adjoint) +# ✖️✖️✖️TODO: Deal with complex-valued arrays as well +function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) + function Adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(adjoint(ȳ))) + end + return Adjoint(A), Adjoint_pullback +end + +function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) + function Adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(vec(adjoint(ȳ)))) + end + return Adjoint(A), Adjoint_pullback +end + +function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) + function adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(adjoint(ȳ))) + end + return adjoint(A), adjoint_pullback +end -rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), Rule(adjoint) -rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), Rule(vec∘adjoint) +function rrule(::typeof(adjoint), A::AbstractVector{<:Real}) + function adjoint_pullback(ȳ) + return (NO_FIELDS, @thunk(vec(adjoint(ȳ)))) + end + return adjoint(A), adjoint_pullback +end ##### ##### `Transpose` ##### -rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), Rule(transpose) -rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), Rule(vec∘transpose) +function rrule(::Type{<:Transpose}, A::AbstractMatrix) + function Transpose_pullback(ȳ) + return (NO_FIELDS, @thunk transpose(ȳ)) + end + return Transpose(A), Transpose_pullback +end + +function rrule(::Type{<:Transpose}, A::AbstractVector) + function Transpose_pullback(ȳ) + return (NO_FIELDS, @thunk vec(transpose(ȳ))) + end + return Transpose(A), Transpose_pullback +end + +function rrule(::typeof(transpose), A::AbstractMatrix) + function transpose_pullback(ȳ) + return (NO_FIELDS, @thunk transpose(ȳ)) + end + return transpose(A), transpose_pullback +end -rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), Rule(transpose) -rrule(::typeof(transpose), A::AbstractVector) = transpose(A), Rule(vec∘transpose) +function rrule(::typeof(transpose), A::AbstractVector) + function transpose_pullback(ȳ) + return (NO_FIELDS, @thunk vec(transpose(ȳ))) + end + return transpose(A), transpose_pullback +end ##### ##### Triangular matrices ##### -rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix) +function rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) + function UpperTriangular_pullback(ȳ) + return (NO_FIELDS, @thunk Matrix(ȳ)) + end + return UpperTriangular(A), UpperTriangular_pullback +end -rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix) +function rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) + function LowerTriangular_pullback(ȳ) + return (NO_FIELDS, @thunk Matrix(ȳ)) + end + return LowerTriangular(A), LowerTriangular_pullback +end diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 2be434ce3..0c40fa36b 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -9,13 +9,27 @@ _denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1) # TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36 function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) - _, dx = rrule(sum, x; dims=dims) + y_sum, sum_pullback = rrule(sum, x; dims=dims) n = _denom(x, dims) - return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n) + function mean_pullback(ȳ) + ∂x = Thunk() do + _, ∂sum_x = sum_pullback(ȳ) + extern(∂sum_x) / n + end + return (NO_FIELDS, ∂x) + end + return y_sum / n, mean_pullback end function rrule(::typeof(mean), f, x::AbstractArray{<:Real}) - _, (_, dx) = rrule(sum, f, x) + y_sum, sum_pullback = rrule(sum, f, x) n = _denom(x, :) - return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n)) + function mean_pullback(ȳ) + ∂x = Thunk() do + _, _, ∂sum_x = sum_pullback(ȳ) + extern(∂sum_x) / n + end + return (NO_FIELDS, DNE(), ∂x) + end + return y_sum / n, mean_pullback end diff --git a/test/helper_functions.jl b/test/helper_functions.jl index 7d3a8d170..062cb631a 100644 --- a/test/helper_functions.jl +++ b/test/helper_functions.jl @@ -19,10 +19,11 @@ end @testset "_update! NamedTuple" begin X = (A=[1 0; 0 1], B=[2 2; 2 2]) + old_X = deepcopy(X) Y = deepcopy(X) - @test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4]) - @test X.A != Y.A - @test X.B != Y.B + @test ChainRules._update!(X, Y, :A) == (A=[2 0; 0 2], B=[2 2; 2 2]) + @test X.A != old_X.A + @test X.B == old_X.B end @testset "_checked_rrule" begin try diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 3a70b9ecd..c113b6698 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -1,20 +1,24 @@ @testset "reshape" begin rng = MersenneTwister(1) A = randn(rng, 4, 5) - B, (dA, dd) = rrule(reshape, A, (5, 4)) + B, pullback = rrule(reshape, A, (5, 4)) @test B == reshape(A, (5, 4)) - @test dd isa ChainRules.DNERule Ȳ = randn(rng, 4, 5) - Ā = dA(Ȳ) - @test Ā == reshape(Ȳ, (5, 4)) - B, (dA, dd1, dd2) = rrule(reshape, A, 5, 4) + (s̄, Ā, d̄) = pullback(Ȳ) + @test s̄ == NO_FIELDS + @test d̄ isa DNE + @test extern(Ā) == reshape(Ȳ, (5, 4)) + + B, pullback = rrule(reshape, A, 5, 4) @test B == reshape(A, 5, 4) - @test dd1 isa ChainRules.DNERule - @test dd2 isa ChainRules.DNERule + Ȳ = randn(rng, 4, 5) - Ā = dA(Ȳ) - @test Ā == reshape(Ȳ, 5, 4) + (s̄, Ā, d̄1, d̄2) = pullback(Ȳ) + @test s̄ == NO_FIELDS + @test d̄1 isa DNE + @test d̄2 isa DNE + @test extern(Ā) == reshape(Ȳ, 5, 4) end @testset "hcat" begin @@ -22,12 +26,14 @@ end A = randn(rng, 3, 2) B = randn(rng, 3) C = randn(rng, 3, 3) - H, (dA, dB, dC) = rrule(hcat, A, B, C) + H, pullback = rrule(hcat, A, B, C) @test H == hcat(A, B, C) H̄ = randn(rng, 3, 6) - @test dA(H̄) ≈ view(H̄, :, 1:2) - @test dB(H̄) ≈ view(H̄, :, 3) - @test dC(H̄) ≈ view(H̄, :, 4:6) + (ds, dA, dB, dC) = pullback(H̄) + @test ds == NO_FIELDS + @test dA ≈ view(H̄, :, 1:2) + @test dB ≈ view(H̄, :, 3) + @test dC ≈ view(H̄, :, 4:6) end @testset "vcat" begin @@ -35,22 +41,28 @@ end A = randn(rng, 2, 4) B = randn(rng, 1, 4) C = randn(rng, 3, 4) - V, (dA, dB, dC) = rrule(vcat, A, B, C) + V, pullback = rrule(vcat, A, B, C) @test V == vcat(A, B, C) V̄ = randn(rng, 6, 4) - @test dA(V̄) ≈ view(V̄, 1:2, :) - @test dB(V̄) ≈ view(V̄, 3:3, :) - @test dC(V̄) ≈ view(V̄, 4:6, :) + (ds, dA, dB, dC) = pullback(V̄) + @test ds == NO_FIELDS + @test dA ≈ view(V̄, 1:2, :) + @test dB ≈ view(V̄, 3:3, :) + @test dC ≈ view(V̄, 4:6, :) end @testset "fill" begin - y, (dv, dd) = rrule(fill, 44, 4) + y, pullback = rrule(fill, 44, 4) @test y == [44, 44, 44, 44] - @test dd isa ChainRules.DNERule - @test dv(ones(Int, 4)) == 4 + (ds, dv, dd) = pullback(ones(4)) + @test ds === NO_FIELDS + @test dd isa DNE + @test extern(dv) == 4 - y, (dv, dd) = rrule(fill, 2.0, (3, 3, 3)) + y, pullback = rrule(fill, 2.0, (3, 3, 3)) @test y == fill(2.0, (3, 3, 3)) - @test dd isa ChainRules.DNERule - @test dv(ones(3, 3, 3)) ≈ 27.0 + (ds, dv, dd) = pullback(ones(3, 3, 3)) + @test ds === NO_FIELDS + @test dd isa DNE + @test dv ≈ 27.0 end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a4d0d4bd8..c8fd73827 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -53,24 +53,38 @@ end @testset "Multivariate" begin x, y = rand(2) - ratan = atan(x, y) # https://en.wikipedia.org/wiki/Atan2 - u = x^2 + y^2 - datan = y/u - 2x/u - r, df = frule(atan, x, y) - @test r === ratan - @test df(1, 2) === datan - r, (df1, df2) = rrule(atan, x, y) - @test r === ratan - @test df1(1) + df2(2) === datan - - rsincos = sincos(x) - dsincos = cos(x) - 2sin(x) - r, (df1, df2) = frule(sincos, x) - @test r === rsincos - @test df1(1) + df2(2) === dsincos - r, df = rrule(sincos, x) - @test r === rsincos - @test df(1, 2) === dsincos + @testset "atan2" begin + # https://en.wikipedia.org/wiki/Atan2 + ratan = atan(x, y) + u = x^2 + y^2 + datan = y/u - 2x/u + + r, pushforward = frule(atan, x, y) + @test r === ratan + @test pushforward(NamedTuple(), 1, 2) === datan + + r, pullback = rrule(atan, x, y) + @test r === ratan + dself, df1, df2 = pullback(1) + @test dself == NO_FIELDS + @test df1 + 2df2 === datan + end + + @testset "sincos" begin + rsincos = sincos(x) + dsincos = cos(x) - 2sin(x) + + r, pushforward = frule(sincos, x) + @test r === rsincos + df1, df2 = pushforward(NamedTuple(), 1) + @test df1 + 2df2 === dsincos + + r, pullback = rrule(sincos, x) + @test r === rsincos + ds, df = pullback(1, 2) + @test df === dsincos + @test ds === NO_FIELDS + end end end # Trig @@ -116,22 +130,26 @@ @testset "*(x, y)" begin x, y = rand(3, 2), rand(2, 5) - z, (dx, dy) = rrule(*, x, y) + z, pullback = rrule(*, x, y) @test z == x * y z̄ = rand(3, 5) + (ds, dx, dy) = pullback(z̄) + + @test ds === NO_FIELDS - @test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄)) - @test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄)) + @test extern(dx) == extern(accumulate(zeros(3, 2), dx)) + @test extern(dy) == extern(accumulate(zeros(2, 5), dy)) - test_accumulation(rand(3, 2), dx, z̄, z̄ * y') - test_accumulation(rand(2, 5), dy, z̄, x' * z̄) + test_accumulation(rand(3, 2), dx) + test_accumulation(rand(2, 5), dy) end @testset "hypot(x, y)" begin x, y = rand(2) - h, dxy = frule(hypot, x, y) + h, pushforward = frule(hypot, x, y) + dxy(x, y) = pushforward(NamedTuple(), x, y) @test extern(dxy(One(), Zero())) === x / h @test extern(dxy(Zero(), One())) === y / h @@ -149,7 +167,6 @@ @testset "identity" begin rng = MersenneTwister(1) - n = 4 rrule_test(identity, randn(rng), (randn(rng), randn(rng))) rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4))) end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 74660db9e..9dc592008 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -1,20 +1,32 @@ @testset "broadcast" begin - @testset "Misc. Tests" begin - @testset "sin.(x)" begin + @testset "sin.(x)" begin + @testset "rrule" begin x = rand(3, 3) - y, (dsin, dx) = rrule(broadcast, sin, x) - + y, pullback = rrule(broadcast, sin, x) @test y == sin.(x) - @test extern(dx(One())) == cos.(x) + (dself, dsin, dx) = pullback(One()) + @test dself == NO_FIELDS + @test dsin == DNE() + @test extern(dx) == cos.(x) x̄, ȳ = rand(), rand() + ∂x = pullback(ȳ)[3] @test isequal( - extern(ChainRules.accumulate(x̄, dx, ȳ)), + extern(ChainRules.accumulate(x̄, ∂x)), x̄ .+ ȳ .* cos.(x) ) x̄, ȳ = Zero(), rand(3, 3) - @test extern(accumulate(x̄, dx, ȳ)) == ȳ .* cos.(x) + ∂x = pullback(ȳ)[3] + @test extern(extern(accumulate(x̄, ∂x))) == ȳ .* cos.(x) + end + @testset "frule" begin + x = rand(3, 3) + y, pushforward = frule(broadcast, sin, x) + @test y == sin.(x) + + ẏ = pushforward(NamedTuple(), NamedTuple(), One()) + @test extern(ẏ) == cos.(x) end end end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 6c3892067..a8699663f 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -15,11 +15,12 @@ vx = randn(rng, n) ȳ = randn(rng) rrule_test(mapreduce, ȳ, (sin, nothing), (+, nothing), (x, vx)) + # With keyword arguments (not yet supported in rrule_test) X = randn(rng, n, n) - y, (_, _, dx) = rrule(mapreduce, abs2, +, X; dims=2) + y, pullback = rrule(mapreduce, abs2, +, X; dims=2) ȳ = randn(rng, size(y)) - x̄_ad = dx(ȳ) + (_, _, _, x̄_ad) = pullback(ȳ) x̄_fd = j′vp(central_fdm(5, 1), x->mapreduce(abs2, +, x; dims=2), ȳ, X) @test x̄_ad ≈ x̄_fd atol=1e-9 rtol=1e-9 end @@ -56,10 +57,10 @@ @testset "keyword arguments" begin rng = MersenneTwister(33) n = 4 - X = randn(rng, n, n) - y, dX = rrule(sum, X; dims=2) + X = randn(rng, n, n+1) + y, pullback = rrule(sum, X; dims=2) ȳ = randn(rng, size(y)) - x̄_ad = dX(ȳ) + _, x̄_ad = pullback(ȳ) x̄_fd = j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X) @test x̄_ad ≈ x̄_fd atol=1e-9 rtol=1e-9 end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index dcc861b5b..beacac383 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -66,17 +66,20 @@ end end @testset "$f" for f in [/, \] rng = MersenneTwister(42) - for n in 3:5, m in 3:5 - A = randn(rng, m, n) - B = randn(rng, m, n) - Ȳ = randn(rng, size(f(A, B))) - rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n))) + @testset "Matrix" begin + for n in 3:5, m in 3:5 + A = randn(rng, m, n) + B = randn(rng, m, n) + Ȳ = randn(rng, size(f(A, B))) + rrule_test(f, Ȳ, (A, randn(rng, m, n)), (B, randn(rng, m, n))) + end + end + @testset "Vector" begin + x = randn(rng, 10) + y = randn(rng, 10) + ȳ = randn(rng, size(f(x, y))...) + rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10))) end - # Vectors - x = randn(rng, 10) - y = randn(rng, 10) - ȳ = randn(rng, size(f(x, y))...) - rrule_test(f, ȳ, (x, randn(rng, 10)), (y, randn(rng, 10))) if f == (/) @testset "$T on the RHS" for T in (Diagonal, UpperTriangular, LowerTriangular) RHS = T(randn(rng, T == Diagonal ? 10 : (10, 10))) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 7713a59f5..0b29cec96 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -2,33 +2,49 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @testset "Factorizations" begin @testset "svd" begin - rng = MersenneTwister(2) + rng = MersenneTwister(3) for n in [4, 6, 10], m in [3, 5, 10] X = randn(rng, n, m) - F, dX = rrule(svd, X) + F, dX_pullback = rrule(svd, X) for p in [:U, :S, :V] - Y, (dF, dp) = rrule(getproperty, F, p) - @test dp isa ChainRules.DNERule + Y, dF_pullback = rrule(getproperty, F, p) Ȳ = randn(rng, size(Y)...) - X̄_ad = dX(dF(Ȳ)) + + dself1, dF, dp = dF_pullback(Ȳ) + @test dself1 === NO_FIELDS + @test dp === DNE() + + ΔF = extern(dF) + dself2, dX = dX_pullback(ΔF) + @test dself2 === NO_FIELDS + X̄_ad = extern(dX) X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X) @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 end - @test_throws ArgumentError rrule(getproperty, F, :Vt) + @testset "Vt" begin + Y, dF_pullback = rrule(getproperty, F, :Vt) + Ȳ = randn(rng, size(Y)...) + @test_throws ArgumentError dF_pullback(Ȳ) + end end + @testset "accumulate!" begin X = [1.0 2.0; 3.0 4.0; 5.0 6.0] - F, dX = rrule(svd, X) + F, dX_pullback = rrule(svd, X) X̄ = (U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) for p in [:U, :S, :V] - Y, (dF, _) = rrule(getproperty, F, p) + Y, dF_pullback = rrule(getproperty, F, p) Ȳ = ones(size(Y)...) - ChainRules.accumulate!(X̄, dF, Ȳ) + (dself, dF, dp) = dF_pullback(Ȳ) + @test dself === NO_FIELDS + @test dp === DNE() + ChainRules.accumulate!(X̄, dF) end @test X̄.U ≈ ones(3, 2) atol=1e-6 @test X̄.S ≈ ones(2) atol=1e-6 @test X̄.V ≈ ones(2, 2) atol=1e-6 end + @testset "Helper functions" begin X = randn(rng, 10, 10) Y = randn(rng, 10, 10) @@ -42,17 +58,22 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @testset "the thing" begin X = generate_well_conditioned_matrix(rng, 10) V = generate_well_conditioned_matrix(rng, 10) - F, dX = rrule(cholesky, X) + F, dX_pullback = rrule(cholesky, X) for p in [:U, :L] - Y, (dF, dp) = rrule(getproperty, F, p) - @test dp isa ChainRules.DNERule + Y, dF_pullback = rrule(getproperty, F, p) Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y))) + (dself, dF, dp) = dF_pullback(Ȳ) + @test dself === NO_FIELDS + @test dp === DNE() + # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` # machinery from FiniteDifferences because that isn't set up to respect # necessary special properties of the input. In the case of the Cholesky # factorization, we need the input to be Hermitian. - X̄_ad = dot(dX(dF(Ȳ)), V) - X̄_fd = central_fdm(5, 1)() do ε + ΔF = extern(dF) + _, dX = dX_pullback(ΔF) + X̄_ad = dot(extern(dX), V) + X̄_fd = _fdm() do ε dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)) end @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 diff --git a/test/rulesets/Statistics/statistics.jl b/test/rulesets/Statistics/statistics.jl index b36d5fbe9..c31af6f15 100644 --- a/test/rulesets/Statistics/statistics.jl +++ b/test/rulesets/Statistics/statistics.jl @@ -1,11 +1,33 @@ @testset "mean" begin rng = MersenneTwister(999) n = 9 - rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n))) - X = randn(rng, n, n) - y, dX = rrule(mean, X; dims=1) - ȳ = randn(rng, size(y)) - X̄_ad = dX(ȳ) - X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X) - @test X̄_ad ≈ X̄_fd rtol=1e-9 atol=1e-9 + + @testset "Basic" begin + rrule_test( + mean, + randn(rng), + (randn(rng, n), + randn(rng, n)) + ) + end + + @testset "with function arg" begin + rrule_test( + mean, + randn(rng), + (abs2, nothing), + (randn(rng, n), + randn(rng, n)) + ) + end + + @testset "with dims kwargs" begin + X = randn(rng, n, n+1) + y, mean_pullback = rrule(mean, X; dims=1) + ȳ = randn(rng, size(y)) + _, dX = mean_pullback(ȳ) + X̄_ad = extern(dX) + X̄_fd = j′vp(_fdm, x->mean(x, dims=1), ȳ, X) + @test X̄_ad ≈ X̄_fd rtol=1e-9 atol=1e-9 + end end diff --git a/test/runtests.jl b/test/runtests.jl index 7512218e3..7379bb2f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,7 +12,7 @@ using Test # For testing purposes we use a lot of using ChainRulesCore: cast, extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, - Zero, One, Casted, DNE, Thunk, DNERule, AbstractDifferential + Zero, One, Casted, DNE, Thunk, AbstractDifferential include("test_util.jl") diff --git a/test/test_util.jl b/test/test_util.jl index 571b9c36f..6e09cbcf8 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -5,6 +5,7 @@ using ChainRulesCore: AbstractDifferential const _fdm = central_fdm(5, 1) + """ test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), test_wirtinger=x isa Complex, kwargs...) @@ -20,16 +21,28 @@ at input point `x` to confirm that there are correct ChainRules provided. All keyword arguments except for `fdm` and `test_wirtinger` are passed to `isapprox`. """ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa Complex, kwargs...) + ensure_not_running_on_functor(f, "test_scalar") + @testset "$f at $x, $(nameof(rule))" for rule in (rrule, frule) res = rule(f, x) @test res !== nothing # Check the rule was defined - fx, ∂x = res + fx, prop_rule = res @test fx == f(x) # Check we still get the normal value, right + if rule == rrule + ∂self, ∂x = prop_rule(1) + @test ∂self === NO_FIELDS + else # rule == frule + # Got to input extra first aguement for internals + # But it is only a dummy since this is not a functor + ∂x, = prop_rule(NamedTuple(), 1) + end + + # Check that we get the derivative right: if !test_wirtinger @test isapprox( - ∂x(1), fdm(f, x); + ∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs... ) else @@ -39,17 +52,28 @@ function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, test_wirtinger=x isa ∂ = 0.5(∂Re - im*∂Im) ∂̅ = 0.5(∂Re + im*∂Im) @test isapprox( - wirtinger_primal(∂x(1)), ∂; + wirtinger_primal(∂x), ∂; rtol=rtol, atol=atol, kwargs... ) @test isapprox( - wirtinger_conjugate(∂x(1)), ∂̅; + wirtinger_conjugate(∂x), ∂̅; rtol=rtol, atol=atol, kwargs... ) end end end +function ensure_not_running_on_functor(f, name) + # if x itself is a Type, then it is a constructor, thus not a functor. + # This also catchs UnionAll constructors which have a `:var` and `:body` fields + f isa Type && return + + if fieldcount(typeof(f)) > 0 + throw(ArgumentError( + "$name cannot be used on closures/functors (such as $f)" + )) + end +end """ frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) @@ -66,14 +90,18 @@ function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) end function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) + ensure_not_running_on_functor(f, "frule_test") xs, ẋs = collect(zip(xẋs...)) - Ω, dΩ_rule = ChainRules.frule(f, xs...) + Ω, pushforward = ChainRules.frule(f, xs...) @test f(xs...) == Ω + dΩ_ad = pushforward(NamedTuple(), ẋs...) - dΩ_ad, dΩ_fd = dΩ_rule(ẋs...), jvp(fdm, xs->f(xs...), (xs, ẋs)) + # Correctness testing via finite differencing. + dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) @test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...) end + """ rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) @@ -86,17 +114,20 @@ end All keyword arguments except for `fdm` are passed to `isapprox`. """ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) + ensure_not_running_on_functor(f, "rrule_test") + # Check correctness of evaluation. - fx, dx = ChainRules.rrule(f, x) + fx, pullback = ChainRules.rrule(f, x) @test fx ≈ f(x) - + (∂self, x̄_ad) = pullback(ȳ) + @test ∂self === NO_FIELDS # No internal fields # Correctness testing via finite differencing. - x̄_ad, x̄_fd = dx(ȳ), j′vp(fdm, f, ȳ, x) + x̄_fd = j′vp(fdm, f, ȳ, x) @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) # Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct. - test_accumulation(x̄, dx, ȳ, x̄_ad) - test_accumulation(Zero(), dx, ȳ, x̄_ad) + test_accumulation(x̄, x̄_ad) + test_accumulation(Zero(), x̄_ad) end function _make_fdm_call(fdm, f, ȳ, xs, ignores) @@ -127,15 +158,20 @@ function _make_fdm_call(fdm, f, ȳ, xs, ignores) end function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) + ensure_not_running_on_functor(f, "rrule_test") + # Check correctness of evaluation. xs, x̄s = collect(zip(xx̄s...)) - y, rules = rrule(f, xs...) + y, pullback = rrule(f, xs...) @test f(xs...) == y + @assert !(isa(ȳ, Thunk)) + ∂s = pullback(ȳ) + ∂self = ∂s[1] + x̄s_ad = ∂s[2:end] + @test ∂self === NO_FIELDS + # Correctness testing via finite differencing. - x̄s_ad = map(rules) do rule - rule isa DNERule ? DNE() : rule(ȳ) - end x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing) for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd) if x̄_fd === nothing @@ -147,10 +183,10 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm end # Assuming the above to be correct, check that other ChainRules mechanisms are correct. - for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad) + for (x̄, x̄_ad) in zip(x̄s, x̄s_ad) x̄ === nothing && continue - test_accumulation(x̄, rule, ȳ, x̄_ad) - test_accumulation(Zero(), rule, ȳ, x̄_ad) + test_accumulation(x̄, x̄_ad) + test_accumulation(Zero(), x̄_ad) end end @@ -167,51 +203,51 @@ function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) return isapprox(extern(d_ad), d_fd; kwargs...) end -function test_accumulation(x̄, dx, ȳ, partial) - @test all(extern(x̄ + partial) .≈ extern(x̄) .+ extern(partial)) - test_accumulate(x̄, dx, ȳ, partial) - test_accumulate!(x̄, dx, ȳ, partial) - test_store!(x̄, dx, ȳ, partial) - return nothing +function test_accumulation(x̄, ∂x) + @test all(extern(x̄ + ∂x) .≈ extern(x̄) .+ extern(∂x)) + test_accumulate(x̄, ∂x) + test_accumulate!(x̄, ∂x) + test_store!(x̄, ∂x) end -function test_accumulate(x̄::Zero, dx, ȳ, partial) - @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(partial) - return nothing +function test_accumulate(x̄::Zero, ∂x) + @test extern(accumulate(x̄, ∂x)) ≈ extern(∂x) end -function test_accumulate(x̄::Number, dx, ȳ, partial) - @test extern(accumulate(x̄, dx, ȳ)) ≈ extern(x̄) + extern(partial) - return nothing +function test_accumulate(x̄::Number, ∂x) + @test extern(accumulate(x̄, ∂x)) ≈ extern(x̄) + extern(∂x) end -function test_accumulate(x̄::AbstractArray, dx, ȳ, partial) +function test_accumulate(x̄::AbstractArray, ∂x) x̄_old = copy(x̄) - @test all(extern(accumulate(x̄, dx, ȳ)) .≈ (extern(x̄) .+ extern(partial))) - @test x̄ == x̄_old - return nothing + @test all(extern(accumulate(x̄, ∂x)) .≈ (extern(x̄) .+ extern(∂x))) + @test x̄ == x̄_old # make sure didn't mutate x̄ end -test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing +test_accumulate!(x̄::Zero, ∂x) = nothing -function test_accumulate!(x̄::Number, dx, ȳ, partial) - @test accumulate!(x̄, dx, ȳ) ≈ accumulate(x̄, dx, ȳ) - return nothing +function test_accumulate!(x̄::Number, ∂x) + # This case won't have been inplace as `Number` is immutable + @test accumulate!(x̄, ∂x) ≈ accumulate(x̄, ∂x) end -function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial) +function test_accumulate!(x̄::AbstractArray, ∂x) x̄_copy = copy(x̄) - accumulate!(x̄_copy, dx, ȳ) - @test extern(x̄_copy) ≈ (extern(x̄) .+ extern(partial)) - return nothing + + accumulate!(x̄_copy, ∂x) # this should have actually been in-place + @test extern(x̄_copy) ≈ (extern(x̄) .+ extern(∂x)) end -test_store!(x̄::Zero, dx, ȳ, partial) = nothing -test_store!(x̄::Number, dx, ȳ, partial) = nothing +test_store!(x̄::Zero, ∂x) = nothing +test_store!(x̄::Number, ∂x) = nothing -function test_store!(x̄::AbstractArray, dx, ȳ, partial) - x̄_copy = copy(x̄) - store!(x̄_copy, dx, ȳ) - @test all(x̄_copy .≈ extern(partial)) - return nothing +function test_store!(x̄::AbstractArray, ∂x) + x̄_store = copy(x̄) + store!(x̄_store, ∂x) + @test x̄_store ≈ extern(∂x) + + # store! is the same as `accumulate!` to a zero array + x̄_acc = false.*x̄ + accumulate!(x̄_acc, ∂x) + @test x̄_acc ≈ x̄_store end