Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
366d983
[WIP] include dervative WRT self.
oxinabox Aug 27, 2019
4e9e330
wip
oxinabox Aug 28, 2019
4f641d6
WIP
oxinabox Aug 28, 2019
c575895
[WIP] make changes to all the rules to return WRT self
oxinabox Aug 28, 2019
0624600
comment out more accumulate
oxinabox Aug 29, 2019
6c7478d
WIP:
oxinabox Aug 30, 2019
05d0c06
all real scalar rules working
oxinabox Sep 2, 2019
7818a63
Wirtinger scalars passing
oxinabox Sep 2, 2019
1a43009
all tests in tests/rulesets/Base/base.jl
oxinabox Sep 3, 2019
6592c95
Fixup Base tests to match frule not returning a tuple
oxinabox Sep 5, 2019
74434c5
attay test passing
oxinabox Sep 5, 2019
1937ec1
Broadcast fixed
oxinabox Sep 5, 2019
facf994
WIP fixing up mapreduce file
oxinabox Sep 5, 2019
9849149
make structured and dense rulesets pass
oxinabox Sep 6, 2019
908b7ea
BLAS written but need to re-sort out update rules before done proper
oxinabox Sep 6, 2019
68c89e2
BLAS rules working but update accumulation inplace is diabled
oxinabox Sep 10, 2019
33ae54e
Factorizations working
oxinabox Sep 10, 2019
dff859a
Make statistics work
oxinabox Sep 10, 2019
a434f2a
remove double extern
oxinabox Sep 10, 2019
aafadfa
fix bad rebase
oxinabox Sep 11, 2019
4b71325
WIP use InplaceableThunks for updating rules
oxinabox Sep 11, 2019
c971052
make factorizations accumulate! right
oxinabox Sep 11, 2019
8980926
style and typos
oxinabox Sep 17, 2019
8c1e418
use _fdm rather than making a new central_fdm
oxinabox Sep 17, 2019
b5d3ba9
set version correctly
oxinabox Sep 18, 2019
9eee9db
name some pullbacks
oxinabox Sep 18, 2019
1288748
name more propagators
oxinabox Sep 18, 2019
f2e5705
More named propagators
oxinabox Sep 18, 2019
a566a0d
Name more propagators
oxinabox Sep 18, 2019
85d0fa0
More named propagators
oxinabox Sep 18, 2019
0f9b3ac
delete extra unused _update! methods
oxinabox Sep 18, 2019
cc9028d
name more progators
oxinabox Sep 18, 2019
add271f
fix up typos and extern new thunks in tests
oxinabox Sep 18, 2019
3dbdbfc
test nonsquares
oxinabox Sep 18, 2019
6ca8deb
more named propagators
oxinabox Sep 18, 2019
320dba0
Apply suggestions from code review
oxinabox Sep 19, 2019
c3c20f6
more named propagators
oxinabox Sep 19, 2019
bcb24db
Update test/rulesets/Base/base.jl
oxinabox Sep 19, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.1.1"
version = "0.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "^0.2"
ChainRulesCore = "^0.3"
FiniteDifferences = "^0.7"
julia = "^1.0"

Expand Down
24 changes: 11 additions & 13 deletions src/helper_functions.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# Special purpose updating for operations which can be done in-place. This function is
# just internal and free-form; it is not a method of `accumulate!` directly as it does
# not adhere to the expected method signature form, i.e. `accumulate!(value, rule, args)`.
# Instead it's `_update!(old, new, extrastuff...)` and is not specific to any particular
# rule.
# Internal helpers for defining the `add!` field of an `InplaceableThunk`

_update!(x, y) = x + y
_update!(x::Array{T,N}, y::AbstractArray{T,N}) where {T,N} = x .+= y
Expand All @@ -11,20 +7,22 @@ _update!(x, ::Zero) = x
_update!(::Zero, y) = y
_update!(::Zero, ::Zero) = Zero()

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}) where Ns
return NamedTuple{Ns}(map(p->_update!(getproperty(x, p), getproperty(y, p)), Ns))
end

function _update!(x::NamedTuple, y, p::Symbol)
new = NamedTuple{(p,)}((_update!(getproperty(x, p), y),))
y = extern(y)
yp = getproperty(y, p)
xp = getproperty(x, p)
new_xp = _update!(xp, yp)
new = NamedTuple{(p,)}((new_xp,))
return merge(x, new)
end

function _update!(x::NamedTuple{Ns}, y::NamedTuple{Ns}, p::Symbol) where Ns
return _update!(x, getproperty(y, p), p)
end

"""
_checked_rrule

like `rrule` but throws an error if the `rrule` is not defined.
Rather than returning `nothing`
"""
function _checked_rrule(f, args...; kwargs...)
r = rrule(f, args...; kwargs...)
r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...)
Expand Down
64 changes: 41 additions & 23 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,56 +3,74 @@
#####

function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}})
return reshape(A, dims), (Rule(Ȳ->reshape(Ȳ, dims)), DNERule())
function reshape_pullback(Ȳ)
return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DNE())
end
return reshape(A, dims), reshape_pullback
end

function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
Y, (rule, _) = rrule(reshape, A, dims)
return Y, (rule, fill(DNERule(), length(dims))...)
function reshape_pullback(Ȳ)
∂A = @thunk(reshape(Ȳ, dims))
return (NO_FIELDS, ∂A, fill(DNE(), length(dims))...)
end
return reshape(A, dims...), reshape_pullback
end

#####
##### `hcat` (🐈)
#####

function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...)
Y = hcat(A, Bs...)
Xs = (A, Bs...)
rules = ntuple(length(Bs) + 1) do i
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
u = l + size(Xs[i], 2)
dim = u > l + 1 ? (l+1:u) : u
# NOTE: The copy here is defensive, since `selectdim` returns a view which we can
# materialize with `copy`
Rule(Ȳ->copy(selectdim(Ȳ, 2, dim)))
function hcat_pullback(Ȳ)
Xs = (A, Bs...)
ntuple(length(Bs) + 2) do full_i
full_i == 1 && return NO_FIELDS

i = full_i - 1
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
u = l + size(Xs[i], 2)
dim = u > l + 1 ? (l+1:u) : u
# NOTE: The copy here is defensive, since `selectdim` returns a view which we can
# materialize with `copy`
copy(selectdim(Ȳ, 2, dim))
end
end
return Y, rules
return hcat(A, Bs...), hcat_pullback
end

#####
##### `vcat`
#####

function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...)
Y = vcat(A, Bs...)
n = size(A, 1)
∂A = Rule(Ȳ->copy(selectdim(Ȳ, 1, 1:n)))
∂Bs = ntuple(length(Bs)) do i
l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0)
u = l + size(Bs[i], 1)
Rule(Ȳ->copy(selectdim(Ȳ, 1, l+1:u)))
function vcat_pullback(Ȳ)
n = size(A, 1)
∂A = copy(selectdim(Ȳ, 1, 1:n))
∂Bs = ntuple(length(Bs)) do i
l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0)
u = l + size(Bs[i], 1)
copy(selectdim(Ȳ, 1, l+1:u))
end
return (NO_FIELDS, ∂A, ∂Bs...)
end
return Y, (∂A, Bs...)
return vcat(A, Bs...), vcat_pullback
end

#####
##### `fill`
#####

function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}})
return fill(value, dims), (Rule(sum), DNERule())
function fill_pullback(Ȳ)
return (NO_FIELDS, @thunk(sum(Ȳ)), DNE())
end
return fill(value, dims), fill_pullback
end

function rrule(::typeof(fill), value::Any, dims::Int...)
return fill(value, dims), (Rule(sum), ntuple(_->DNERule(), length(dims))...)
function fill_pullback(Ȳ)
return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DNE(), length(dims))...)
end
return fill(value, dims), fill_pullback
end
34 changes: 27 additions & 7 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,30 @@

# product rule requires special care for arguments where `mul` is non-commutative

frule(::typeof(*), x::Number, y::Number) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy)

rrule(::typeof(*), x::Number, y::Number) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ))

frule(::typeof(identity), x) = x, Rule(identity)

rrule(::typeof(identity), x) = x, Rule(identity)
function frule(::typeof(*), x::Number, y::Number)
function times_pushforward(_, Δx, Δy)
return Δx * y + x * Δy
end
return x * y, times_pushforward
end

function rrule(::typeof(*), x::Number, y::Number)
function times_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
end
return x * y, times_pullback
end

function frule(::typeof(identity), x)
function identity_pushforward(_, ẏ)
return ẏ
end
return x, identity_pushforward
end

function rrule(::typeof(identity), x)
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
end
return x, identity_pullback
end
14 changes: 10 additions & 4 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,26 @@ without relying on inference hacks unless we have something akin to
https://github.com/JuliaLang/julia/issues/22129.
=#
function _cast_diff(f, x)
element_rule = u -> begin
function element_rule(u)
fu, du = frule(f, u)
fu, extern(du(One()))
fu, extern(du(NamedTuple(), One()))
end
results = broadcast(element_rule, x)
return first.(results), last.(results)
end

function frule(::typeof(broadcast), f, x)
Ω, ∂x = _cast_diff(f, x)
return Ω, Rule((_, Δx) -> Δx * cast(∂x))
function broadcast_pushforward(_, Δf, Δx)
return Δx * cast(∂x)
end
return Ω, broadcast_pushforward
end

function rrule(::typeof(broadcast), f, x)
values, derivs = _cast_diff(f, x)
return values, (DNERule(), Rule(ΔΩ -> ΔΩ * cast(derivs)))
function broadcast_pullback(ΔΩ)
return (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs)))
end
return values, broadcast_pullback
end
63 changes: 44 additions & 19 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@

function rrule(::typeof(map), f, xs...)
y = map(f, xs...)
∂xs = ntuple(length(xs)) do i
Rule() do ȳ
map(ȳ, xs...) do ȳi, xis...
_, ∂xis = _checked_rrule(f, xis...)
extern(∂xis[i](ȳi))
function map_pullback(ȳ)
ntuple(length(xs)+2) do full_i
full_i == 1 && return NO_FIELDS
full_i == 2 && return DNE()
i = full_i-2
@thunk map(ȳ, xs...) do ȳi, xis...
_, pullback = _checked_rrule(f, xis...)
∂xis = pullback(ȳi)
extern(∂xis[i+1]) #+1 to skp ∂self
end
end
end
return y, (DNERule(), ∂xs...)
return y, map_pullback
end

#####
Expand All @@ -26,15 +30,18 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr)
insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:))))
insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims)))
end
pullback_name = Symbol(mf, :_pullback)
body = quote
y = $call
∂x = Rule() do ȳ
broadcast(x, ȳ) do xi, ȳi
_, ∂xi = _checked_rrule(f, xi)
extern(∂xi(ȳi))
function $pullback_name(ȳ)
∂x = @thunk broadcast(x, ȳ) do xi, ȳi
_, pullback_f = _checked_rrule(f, xi)
_, ∂xi = pullback_f(ȳi)
extern(∂xi)
end
(NO_FIELDS, DNE(), DNE(), ∂x)
end
return y, (DNERule(), DNERule(), ∂x)
return y, $pullback_name
end
eval(Expr(:function, sig, body))
end
Expand All @@ -43,22 +50,40 @@ end
##### `sum`
#####

frule(::typeof(sum), x) = (sum(x), Rule(sum))
function frule(::typeof(sum), x)
function sum_pushforward(_, ẋ)
return sum(ẋ)
end
return sum(x), sum_pushforward
end

rrule(::typeof(sum), x) = (sum(x), Rule(cast))
function rrule(::typeof(sum), x)
function sum_pullback(ȳ)
return (NO_FIELDS, cast(ȳ))
end
return sum(x), sum_pullback
end

function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
y, (_, _, ∂x) = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
return y, (DNERule(), ∂x)
y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
function sum_pullback(ȳ)
NO_FIELDS, DNE(), last(mr_pullback(ȳ))
end
return y, sum_pullback
end

function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
y, (_, ∂x) = rrule(sum, identity, x; dims=dims)
return y, ∂x
y, inner_pullback = rrule(sum, identity, x; dims=dims)
function sum_pullback(ȳ)
NO_FIELDS, last(inner_pullback(ȳ))
end
return y, sum_pullback
end

function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
y = sum(abs2, x; dims=dims)
∂x = Rule(ȳ -> 2ȳ .* x)
return y, (DNERule(), ∂x)
function sum_abs2_pullback(ȳ)
return (NO_FIELDS, DNE(), @thunk(2ȳ .* x))
end
return y, sum_abs2_pullback
end
Loading