diff --git a/Project.toml b/Project.toml index a81fc0d4d..938852959 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.2.1-DEV" +version = "0.3.0" [compat] julia = "^1.0" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 2fd3302c9..118e7f841 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -1,13 +1,16 @@ module ChainRulesCore using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable -export AbstractRule, Rule, frule, rrule +export frule, rrule +export wirtinger_conjugate, wirtinger_primal, refine_differential export @scalar_rule, @thunk -export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule +export extern, cast, store! +export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk +export NO_FIELDS include("differentials.jl") include("differential_arithmetic.jl") -include("rule_types.jl") +include("operations.jl") include("rules.jl") include("rule_definition_tools.jl") end # module diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index d7a78777c..e65748d34 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered. Thus we can avoid any ambiguities. Notice: - The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :Any) + The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) Thus each of the @eval loops creating definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. 
@@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger) return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) end -for T in (:Casted, :Zero, :DNE, :One, :Thunk, :Any) +for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b @@ -47,7 +47,7 @@ end Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value)) Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value)) -for T in (:Zero, :DNE, :One, :Thunk, :Any) +for T in (:Zero, :DNE, :One, :AbstractThunk, :Any) @eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b)) @eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value)) @@ -58,7 +58,7 @@ end Base.:+(::Zero, b::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() -for T in (:DNE, :One, :Thunk, :Any) +for T in (:DNE, :One, :AbstractThunk, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a @@ -69,7 +69,7 @@ end Base.:+(::DNE, ::DNE) = DNE() Base.:*(::DNE, ::DNE) = DNE() -for T in (:One, :Thunk, :Any) +for T in (:One, :AbstractThunk, :Any) @eval Base.:+(::DNE, b::$T) = b @eval Base.:+(a::$T, ::DNE) = a @@ -80,7 +80,7 @@ end Base.:+(a::One, b::One) = extern(a) + extern(b) Base.:*(::One, ::One) = One() -for T in (:Thunk, :Any) +for T in (:AbstractThunk, :Any) @eval Base.:+(a::One, b::$T) = extern(a) + b @eval Base.:+(a::$T, b::One) = a + extern(b) @@ -89,12 +89,12 @@ for T in (:Thunk, :Any) end -Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b) -Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b) -for T in (:Any,) #This loop is redundant but for consistency... 
- @eval Base.:+(a::Thunk, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::Thunk) = a + extern(b) +Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b) +Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b) +for T in (:Any,) + @eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b + @eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b) - @eval Base.:*(a::Thunk, b::$T) = extern(a) * b - @eval Base.:*(a::$T, b::Thunk) = a * extern(b) + @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b + @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) end diff --git a/src/differentials.jl b/src/differentials.jl index 3f3e19d62..5ad2f8818 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing) Base.iterate(::One, ::Any) = nothing +##### +##### `AbstractThunk +##### +abstract type AbstractThunk <: AbstractDifferential end + +Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x)) + +@inline function Base.iterate(x::AbstractThunk) + externed = extern(x) + element, state = iterate(externed) + return element, (externed, state) +end + +@inline function Base.iterate(::AbstractThunk, (externed, state)) + element, new_state = iterate(externed, state) + return element, (externed, new_state) +end + ##### ##### `Thunk` ##### @@ -181,8 +199,9 @@ Base.iterate(::One, ::Any) = nothing Thunk(()->v) A thunk is a deferred computation. It wraps a zero argument closure that when invoked returns a differential. +`@thunk(v)` is a macro that expands into `Thunk(()->v)`. -Calling that thunk, calls the wrapped closure. +Calling a thunk, calls the wrapped closure. `extern`ing thunks applies recursively, it also externs the differial that the closure returns. If you do not want that, then simply call the thunk @@ -199,8 +218,24 @@ Thunk(var"##8#10"()) julia> t()() 3 ``` + +### When to `@thunk`? 
+When writing `rrule`s (and to a lesser exent `frule`s), it is important to `@thunk` +appropriately. +Propagation rule's that return multiple derivatives are not able to do all the computing themselves. + By `@thunk`ing the work required for each, they then compute only what is needed. + +#### So why not thunk everything? +`@thunk` creates a closure over the expression, which (effectively) creates a `struct` +with a field for each variable used in the expression, and call overloaded. + +Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being: +- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)` +- The expression being a constant +- The expression being itself a `thunk` +- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already) """ -struct Thunk{F} <: AbstractDifferential +struct Thunk{F} <: AbstractThunk f::F end @@ -208,22 +243,62 @@ macro thunk(body) return :(Thunk(() -> $(esc(body)))) end +# have to define this here after `@thunk` and `Thunk` is defined +Base.conj(x::AbstractThunk) = @thunk(conj(extern(x))) + + (x::Thunk)() = x.f() @inline extern(x::Thunk) = extern(x()) -Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x)) +Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") -@inline function Base.iterate(x::Thunk) - externed = extern(x) - element, state = iterate(externed) - return element, (externed, state) +""" + InplaceableThunk(val::Thunk, add!::Function) + +A wrapper for a `Thunk`, that allows it to define an inplace `add!` function, +which is used internally in `accumulate!(Δ, ::InplaceableThunk)`. + +`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val` +but it should do this more efficently than simply doing this directly. +(Otherwise one can just use a normal `Thunk`). 
+ +Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; +and destroy its inplacability. +""" +struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk + val::T + add!::F end -@inline function Base.iterate(::Thunk, (externed, state)) - element, new_state = iterate(externed, state) - return element, (externed, new_state) +(x::InplaceableThunk)() = x.val() +@inline extern(x::InplaceableThunk) = extern(x.val) + +function Base.show(io::IO, x::InplaceableThunk) + println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") end -Base.conj(x::Thunk) = @thunk(conj(extern(x))) +# The real reason we have this: +accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) +store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. -Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") +""" + NO_FIELDS + +Constant for the reverse-mode derivative with respect to a structure that has no fields. +The most notable use for this is for the reverse-mode derivative with respect to the +function itself, when that function is not a closure. +""" +const NO_FIELDS = DNE() + +""" + refine_differential(𝒟::Type, der) + +Converts, if required, a differential object `der` +(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), +to another differential that is more suited for the domain given by the type 𝒟. +Often this will behave as the identity function on `der`. +""" +function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) + return wirtinger_primal(w) + wirtinger_conjugate(w) +end +refine_differential(::Any, der) = der # most of the time leave it alone. diff --git a/src/operations.jl b/src/operations.jl new file mode 100644 index 000000000..c60134e9f --- /dev/null +++ b/src/operations.jl @@ -0,0 +1,45 @@ +# TODO: This all needs a fair bit of rethinking + +""" + accumulate(Δ, ∂) + +Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's +various `AbstractDifferential` types. 
+ +See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +""" +accumulate(Δ, ∂) = Δ .+ ∂ + +""" + accumulate!(Δ, ∂) + +Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place, +storing the result in `Δ`. + +Note: this function may not actually store the result in `Δ` if `Δ` is immutable, +so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case. + +This function is overloadable by using a [`InplaceThunk`](@ref). +See also: [`accumulate`](@ref), [`store!`](@ref). +""" +function accumulate!(Δ, ∂) + return materialize!(Δ, broadcastable(cast(Δ) + ∂)) +end + +accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂) + + + +""" + store!(Δ, ∂) + +Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. +potentially avoiding intermediate temporary allocations that might be +necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`) + +Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended +to be customizable for specific rules/input types. + +See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) +""" +store!(Δ, ∂) = materialize!(Δ, broadcastable(∂)) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 1caa0d987..a06820e64 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -14,17 +14,22 @@ methods for `frule` and `rrule`: function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) - return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...), - Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...), - ...) + return Ω, (_, Δx₁, Δx₂, ...) -> ( + (∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...), + (∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...), + ... + ) end function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) - return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) 
-> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), - Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...), - ...) + return Ω, (ΔΩ₁, ΔΩ₂, ...) -> ( + NO_FIELDS, + ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), + ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...), + ... + ) end If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are @@ -34,11 +39,16 @@ Constraints may also be explicitly be provided to override the `Number` constrai e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to `Number`. -Note that the result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This +At present this does not support defining for closures/functors. +Thus in reverse-mode, the first returned partial, +representing the derivative with respect to the function itself, is always `NO_FIELDS`. +And in forward-mode, the first input to the returned propagator is always ignored. + +The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This allows the primal result to be conveniently referenced (as `Ω`) within the derivative/setup expressions. -Note that the `@setup` argument can be elided if no setup code is need. In other +The `@setup` argument can be elided if no setup code is need. In other words: @scalar_rule(f(x₁, x₂, ...), @@ -59,6 +69,49 @@ For examples, see ChainRulesCore' `rules` directory. See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref) """ macro scalar_rule(call, maybe_setup, partials...) + call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( + call, maybe_setup, partials + ) + f = call.args[1] + + # An expression that when evaluated will return the type of the input domain. + # Multiple repetitions of this expression should optimize out. 
But if it does not then + # may need to move its definition into the body of the `rrule`/`frule` + 𝒟 = :(typeof(first(promote($(call.args[2:end]...))))) + + frule_expr = scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials) + rrule_expr = scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials) + + + ############################################################################ + # Final return: building the expression to insert in the place of this macro + code = quote + if !($f isa Type) && fieldcount(typeof($f)) > 0 + throw(ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))" + )) + end + + $(frule_expr) + $(rrule_expr) + end +end + + +""" + _normalize_scalarrules_macro_input(call, maybe_setup, partials) + +returns (in order) the correctly escaped: + - `call` with out any type constraints + - `setup_stmts`: the content of `@setup` or `nothing` if that is not provided, + - `inputs`: with all args having the constraints removed from call, or + defaulting to `Number` + - `partials`: which are all `Expr{:tuple,...}` +""" +function _normalize_scalarrules_macro_input(call, maybe_setup, partials) + ############################################################################ + # Setup: normalizing input form etc + if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup") setup_stmts = map(esc, maybe_setup.args[3:end]) else @@ -66,11 +119,12 @@ macro scalar_rule(call, maybe_setup, partials...) partials = (maybe_setup, partials...) end @assert Meta.isexpr(call, :call) - f = esc(call.args[1]) + # Annotate all arguments in the signature as scalars inputs = map(call.args[2:end]) do arg esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number)) end + # Remove annotations and escape names for the call for (i, arg) in enumerate(call.args) if Meta.isexpr(arg, :(::)) @@ -79,69 +133,164 @@ macro scalar_rule(call, maybe_setup, partials...) 
call.args[i] = esc(arg) end end - if all(Meta.isexpr(partial, :tuple) for partial in partials) - input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input - forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials] - reverse_rules = Any[] - for i in 1:length(inputs) - reverse_partials = [partial.args[i] for partial in partials] - push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...)) + + # For consistency in code that follows we make all partials tuple expressions + partials = map(partials) do partial + if Meta.isexpr(partial, :tuple) + partial + else + length(inputs) == 1 || error("Invalid use of `@scalar_rule`") + Expr(:tuple, partial) end + end + + return call, setup_stmts, inputs, partials +end + +function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials) + n_outputs = length(partials) + n_inputs = length(inputs) + + # Δs is the input to the propagator rule + # because this is push-forward there is one per input to the function + Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] + pushforward_returns = map(1:n_outputs) do output_i + ∂s = partials[output_i].args + propagation_expr(𝒟, Δs, ∂s) + end + if n_outputs > 1 + # For forward-mode we only return a tuple if output actually a tuple. + pushforward_returns = Expr(:tuple, pushforward_returns...) else - @assert length(inputs) == 1 && all(!Meta.isexpr(partial, :tuple) for partial in partials) - forward_rules = Any[rule_from_partials(inputs[1], partial) for partial in partials] - reverse_rules = Any[rule_from_partials(inputs[1], partials...)] + pushforward_returns = pushforward_returns[1] + end + + pushforward = quote + # _ is the input derivative w.r.t. function internals. since we do not + # allow closures/functors with @scalar_rule, it is always ignored + function $(propagator_name(f, :pushforward))(_, $(Δs...)) + $pushforward_returns + end end - forward_rules = length(forward_rules) == 1 ? 
forward_rules[1] : Expr(:tuple, forward_rules...) - reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...) + return quote function ChainRulesCore.frule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) - return $(esc(:Ω)), $forward_rules + return $(esc(:Ω)), $pushforward end + end +end + +function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials) + n_outputs = length(partials) + n_inputs = length(inputs) + + # Δs is the input to the propagator rule + # because this is a pull-back there is one per output of function + Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs] + + # 1 partial derivative per input + pullback_returns = map(1:n_inputs) do input_i + ∂s = [partial.args[input_i] for partial in partials] + propagation_expr(𝒟, Δs, ∂s) + end + + pullback = quote + function $(propagator_name(f, :pullback))($(Δs...)) + return (NO_FIELDS, $(pullback_returns...)) + end + end + + return quote function ChainRulesCore.rrule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) - return $(esc(:Ω)), $reverse_rules + return $(esc(:Ω)), $pullback end end end -function rule_from_partials(input_arg, ∂s...) - wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s) +""" + propagation_expr(𝒟, Δs, ∂s) + + Returns the expression for the propagation of + the input gradient `Δs` though the partials `∂s`. + + 𝒟 is an expression that when evaluated returns the type-of the input domain. + For example if the derivative is being taken at the point `1` it returns `Int`. + if it is taken at `1+1im` it returns `Complex{Int}`. + At present it is ignored for non-Wirtinger derivatives. +""" +function propagation_expr(𝒟, Δs, ∂s) + wirtinger_indices = findall(∂s) do ex + Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger + end ∂s = map(esc, ∂s) - Δs = [Symbol(string(:Δ, i)) for i in 1:length(∂s)] - Δs_tuple = Expr(:tuple, Δs...) 
if isempty(wirtinger_indices) - ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] - return :(Rule($Δs_tuple -> +($(∂_mul_Δs...)))) + return standard_propagation_expr(Δs, ∂s) else - ∂_mul_Δs_primal = Any[] - ∂_mul_Δs_conjugate = Any[] - ∂_wirtinger_defs = Any[] - for i in 1:length(∂s) - if i in wirtinger_indices - Δi = Δs[i] - ∂i = Symbol(string(:∂, i)) - push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) - ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) - ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) - ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) - ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) - push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) - push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) - else - ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) - push!(∂_mul_Δs_primal, ∂_mul_Δ) - push!(∂_mul_Δs_conjugate, ∂_mul_Δ) - end - end - primal_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_primal...)))) - conjugate_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_conjugate...)))) - return quote - $(∂_wirtinger_defs...) 
- AbstractRule(typeof($input_arg), $primal_rule, $conjugate_rule) + return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) + end +end + +function standard_propagation_expr(Δs, ∂s) + # This is basically Δs ⋅ ∂s + + # Notice: the thunking of `∂s[i] (potentially) saves us some computation + # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon + # as the pullback is evaluated + ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] + return :(+($(∂_mul_Δs...))) +end + +function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) + ∂_mul_Δs_primal = Any[] + ∂_mul_Δs_conjugate = Any[] + ∂_wirtinger_defs = Any[] + for i in 1:length(∂s) + if i in wirtinger_indices + Δi = Δs[i] + ∂i = Symbol(string(:∂, i)) + push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) + ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) + ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) + ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) + ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) + push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) + push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) + else + ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) + push!(∂_mul_Δs_primal, ∂_mul_Δ) + push!(∂_mul_Δs_conjugate, ∂_mul_Δ) end end + primal_sum = :(+($(∂_mul_Δs_primal...))) + conjugate_sum = :(+($(∂_mul_Δs_conjugate...))) + return quote # This will be a block, so will have value equal to last statement + $(∂_wirtinger_defs...) + w = Wirtinger($primal_sum, $conjugate_sum) + refine_differential($𝒟, w) + end end + +""" + propagator_name(f, propname) + +Determines a reasonable name for the propagator function. +The name doesn't really matter too much as it is a local function to be returned +by `frule` or `rrule`, but a good name make debugging easier. 
+`f` should be some form of AST representation of the actual function, +`propname` should be either `:pullback` or `:pushforward` + +This is able to deal with fairly complex expressions for `f`: + + julia> propagator_name(:bar, :pushforward) + :bar_pushforward + + julia> propagator_name(esc(:(Base.Random.foo)), :pullback) + :foo_pullback +""" +propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname) +propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) +propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) diff --git a/src/rule_types.jl b/src/rule_types.jl deleted file mode 100644 index 158838883..000000000 --- a/src/rule_types.jl +++ /dev/null @@ -1,210 +0,0 @@ -""" -Subtypes of `AbstractRule` are types which represent the primitive derivative -propagation "rules" that can be composed to implement forward- and reverse-mode -automatic differentiation. - -More specifically, a `rule::AbstractRule` is a callable Julia object generally -obtained via calling [`frule`](@ref) or [`rrule`](@ref). Such rules accept -differential values as input, evaluate the chain rule using internally stored/ -computed partial derivatives to produce a single differential value, then -return that calculated differential value. 
- -For example: - -``` -julia> using ChainRulesCore: frule, rrule, AbstractRule - -julia> x, y = rand(2); - -julia> h, dh = frule(hypot, x, y); - -julia> h == hypot(x, y) -true - -julia> isa(dh, AbstractRule) -true - -julia> Δx, Δy = rand(2); - -julia> dh(Δx, Δy) == ((x / h) * Δx + (y / h) * Δy) -true - -julia> h, (dx, dy) = rrule(hypot, x, y); - -julia> h == hypot(x, y) -true - -julia> isa(dx, AbstractRule) && isa(dy, AbstractRule) -true - -julia> Δh = rand(); - -julia> dx(Δh) == (x / h) * Δh -true - -julia> dy(Δh) == (y / h) * Δh -true -``` - -See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [`WirtingerRule`](@ref) -""" -abstract type AbstractRule end - -# this ensures that consumers don't have to special-case rule destructuring -Base.iterate(rule::AbstractRule) = (rule, nothing) -Base.iterate(::AbstractRule, ::Any) = nothing - -# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple -# in order to get the `i`th rule (assuming it's 1) -Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError()) - -""" - accumulate(Δ, rule::AbstractRule, args...) - -Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore' -various `AbstractDifferential` types. - -This method intended to be customizable for specific rules/input types. For -example, here is pseudocode to overload `accumulate` w.r.t. a specific forward -differentiation rule for a given function `f`: - -``` -df(x) = # forward differentiation primitive implementation - -frule(::typeof(f), x) = (f(x), Rule(df)) - -accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation -``` - -See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -accumulate(Δ, rule::AbstractRule, args...) = Δ + rule(args...) - -""" - accumulate!(Δ, rule::AbstractRule, args...) - -Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place, -storing the result in `Δ`. 
- -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -function accumulate!(Δ, rule::AbstractRule, args...) - return materialize!(Δ, broadcastable(cast(Δ) + rule(args...))) -end - -accumulate!(Δ::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...) - -""" - store!(Δ, rule::AbstractRule, args...) - -Compute `rule(args...)` and store the result in `Δ`, potentially avoiding -intermediate temporary allocations that might be necessary for alternative -approaches (e.g. `copyto!(Δ, extern(rule(args...)))`) - -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended -to be customizable for specific rules/input types. - -See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) -""" -store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...))) - -##### -##### `Rule` -##### - - -""" - Rule(propation_function[, updating_function]) - -Return a `Rule` that wraps the given `propation_function`. It is assumed that -`propation_function` is a callable object whose arguments are differential -values, and whose output is a single differential value calculated by applying -internally stored/computed partial derivatives to the input differential -values. - -If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)` -and to store the result of the propagation function applied to the arguments `xs` into -`Δ` in-place, returning `Δ`. 
- -For example: - -``` -frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy) - -rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) -``` - -See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref) -""" -struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule - f::F - u::U -end - -# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some -# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}` -Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing) - -(rule::Rule)(args...) = rule.f(args...) - -Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))") -Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))") - -# Specialized accumulation -# TODO: Does this need to be overdubbed in the rule context? -accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) - -##### -##### `DNERule` -##### - -""" - DNERule(args...) - -Construct a `DNERule` object, which is an `AbstractRule` that signifies that the -current function is not differentiable with respect to a particular parameter. -**DNE** is an abbreviation for Does Not Exist. -""" -struct DNERule <: AbstractRule end - -DNERule(args...) = DNE() - -##### -##### `WirtingerRule` -##### - -""" - WirtingerRule(primal::AbstractRule, conjugate::AbstractRule) - -Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of -an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate -derivative ``∂/∂x̅``. If the domain `𝒟` of the function might be real, consider -calling `AbstractRule(𝒟, primal, conjugate)` instead, to make use of a more -efficient representation wherever possible. -""" -struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule - primal::P - conjugate::C -end - -function (rule::WirtingerRule)(args...) 
- return Wirtinger(rule.primal(args...), rule.conjugate(args...)) -end - -""" - AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) - -Return a `Rule` evaluating to `primal(Δ) + conjugate(Δ)` if `𝒟 <: Real`, -otherwise return `WirtingerRule(P, C)`. -""" -function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) - if 𝒟 <: Real || eltype(𝒟) <: Real - return Rule((args...) -> (primal(args...) + conjugate(args...))) - else - return WirtingerRule(primal, conjugate) - end -end diff --git a/src/rules.jl b/src/rules.jl index 6a145fad7..371013ced 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -33,7 +33,7 @@ my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...) function Cassette.execute(::MyChainRuleCtx, ::typeof(frule), f, x::Number) r = frule(f, x) if isa(r, Nothing) - fx, df = (f(x), Rule(Δx -> ForwardDiff.derivative(f, x) * Δx)) + fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx) else fx, df = r end @@ -48,16 +48,12 @@ end Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)` as `Ω`, return the tuple: - (Ω, (rule_for_ΔΩ₁::AbstractRule, rule_for_ΔΩ₂::AbstractRule, ...)) + (Ω, (ṡelf, ẋ₁, ẋ₂, ...) -> Ω̇₁, Ω̇₂, ...) -where each returned propagation rule `rule_for_ΔΩᵢ` can be invoked as +The second return value is the propagation rule, or the pushforward. +It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`) +and `ṡelf` the internal values of the function (for closures). - rule_for_ΔΩᵢ(Δx₁, Δx₂, ...) - -to yield `Ωᵢ`'s corresponding differential `ΔΩᵢ`. To illustrate, if all involved -values are real-valued scalars, this differential can be written as: - - ΔΩᵢ = ∂Ωᵢ_∂x₁ * Δx₁ + ∂Ωᵢ_∂x₂ * Δx₂ + ... If no method matching `frule(f, xs...)` has been defined, then return `nothing`. 
@@ -68,12 +64,12 @@ unary input, unary output scalar function: ``` julia> x = rand(); -julia> sinx, dsin = frule(sin, x); +julia> sinx, sin_pushforward = frule(sin, x); julia> sinx == sin(x) true -julia> dsin(1) == cos(x) +julia> sin_pushforward(NamedTuple(), 1) == cos(x) true ``` @@ -82,19 +78,16 @@ unary input, binary output scalar function: ``` julia> x = rand(); -julia> sincosx, (dsin, dcos) = frule(sincos, x); +julia> sincosx, sincos_pushforward = frule(sincos, x); julia> sincosx == sincos(x) true -julia> dsin(1) == cos(x) -true - -julia> dcos(1) == -sin(x) +julia> sincos_pushforward(NamedTuple(), 1) == (cos(x), -sin(x)) true ``` -See also: [`rrule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref) +See also: [`rrule`](@ref), [`@scalar_rule`](@ref) """ frule(::Any, ::Vararg{Any}; kwargs...) = nothing @@ -104,16 +97,11 @@ frule(::Any, ::Vararg{Any}; kwargs...) = nothing Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)` as `Ω`, return the tuple: - (Ω, (rule_for_Δx₁::AbstractRule, rule_for_Δx₂::AbstractRule, ...)) - -where each returned propagation rule `rule_for_Δxᵢ` can be invoked as + (Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...)) - rule_for_Δxᵢ(ΔΩ₁, ΔΩ₂, ...) - -to yield `xᵢ`'s corresponding differential `Δxᵢ`. To illustrate, if all involved -values are real-valued scalars, this differential can be written as: - - Δxᵢ = ∂Ω₁_∂xᵢ * ΔΩ₁ + ∂Ω₂_∂xᵢ * ΔΩ₂ + ... +Where the second return value is the the propagation rule or pullback. +It takes in differentials corresponding to the outputs (`x̄₁, x̄₂, ...`), +and `s̄elf`, the internal values of the function itself (for closures) If no method matching `rrule(f, xs...)` has been defined, then return `nothing`. 
@@ -124,12 +112,12 @@ unary input, unary output scalar function: ``` julia> x = rand(); -julia> sinx, dx = rrule(sin, x); +julia> sinx, sin_pullback = rrule(sin, x); julia> sinx == sin(x) true -julia> dx(1) == cos(x) +julia> sin_pullback(1) == (NO_FIELDS, cos(x)) true ``` @@ -138,18 +126,15 @@ binary input, unary output scalar function: ``` julia> x, y = rand(2); -julia> hypotxy, (dx, dy) = rrule(hypot, x, y); +julia> hypotxy, hypot_pullback = rrule(hypot, x, y); julia> hypotxy == hypot(x, y) true -julia> dx(1) == (x / hypot(x, y)) -true - -julia> dy(1) == (y / hypot(x, y)) +julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y))) true ``` -See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref) +See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ rrule(::Any, ::Vararg{Any}; kwargs...) = nothing diff --git a/test/differentials.jl b/test/differentials.jl index 557c125d8..570b09d88 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -79,4 +79,20 @@ ] @test isempty(ambig_methods) end + + + @testset "Refine Differential" begin + @test refine_differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2) + @test refine_differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2) + + @test refine_differential(typeof(1.2), Wirtinger(2,2)) == 4 + @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4 + + # For most differentials, in most domains, this does nothing + for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) + for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2])) + @test refine_differential(𝒟, der) === der + end + end + end end diff --git a/test/rule_types.jl b/test/rule_types.jl deleted file mode 100644 index 79aa362cc..000000000 --- a/test/rule_types.jl +++ /dev/null @@ -1,61 +0,0 @@ - -@testset "rule types" begin - @testset "iterating and indexing rules" begin - _, rule = frule(dummy_identity, 1) - i = 0 - for r in rule - @test r === rule - i += 1 - end - 
@test i == 1 # rules only iterate once, yielding themselves - @test rule[1] == rule - @test_throws BoundsError rule[2] - end - - @testset "Rule" begin - @testset "show" begin - @test occursin(r"^Rule\(.*foo.*\)$", repr(Rule(function foo() 1 end))) - @test occursin(r"^Rule\(.*identity.*\)$", repr(Rule(identity))) - - @test occursin(r"^Rule\(.*identity.*\,.*\+.*\)$", repr(Rule(identity, +))) - end - end - - @testset "WirtingerRule" begin - myabs2(x) = abs2(x) - - function ChainRulesCore.frule(::typeof(myabs2), x) - return abs2(x), AbstractRule( - typeof(x), - Rule(Δx -> Δx * x'), - Rule(Δx -> Δx * x) - ) - end - - # real input - x = rand(Float64) - f, _df = @inferred frule(myabs2, x) - @test f === x^2 - - df = @inferred _df(One()) - @test df === x + x - - - Δ = rand(Complex{Int64}) - df = @inferred _df(Δ) - @test df === Δ * (x + x) - - - # complex input - z = rand(Complex{Float64}) - f, _df = @inferred frule(myabs2, z) - @test f === abs2(z) - - df = @inferred _df(One()) - @test df === Wirtinger(z', z) - - Δ = rand(Complex{Int64}) - df = @inferred _df(Δ) - @test df === Wirtinger(Δ * z', Δ * z) - end -end diff --git a/test/rules.jl b/test/rules.jl index 7295da7a3..e23680326 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -6,7 +6,6 @@ cool(x, y) = x + y + 1 # a rule we define so we can test rules dummy_identity(x) = x - @scalar_rule(dummy_identity(x), One()) ####### @@ -19,6 +18,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrule(cool, 1) === nothing @test rrule(cool, 1; iscool=true) === nothing + # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @test hasmethod(rrule, Tuple{typeof(cool),Number}) ChainRulesCore.@scalar_rule(Main.cool(x::String), "wow such dfdx") @@ -29,10 +29,123 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) Tuple{typeof(rrule),typeof(cool),String}]) @test cool_methods == only_methods - frx, fr = frule(cool, 1) + frx, cool_pushforward = frule(cool, 1) @test frx == 2 - @test fr(1) == 1 - rrx, 
rr = rrule(cool, 1) + @test cool_pushforward(NamedTuple(), 1) == 1 + rrx, cool_pullback = rrule(cool, 1) + self, rr1 = cool_pullback(1) + @test self == NO_FIELDS @test rrx == 2 - @test rr(1) == 1 + @test rr1 == 1 +end + + +@testset "Basic Wirtinger scalar_rule" begin + myabs2(x) = abs2(x) + @scalar_rule(myabs2(x), Wirtinger(x', x)) + + @testset "real input" begin + # even though our rule was defined in terms of Wirtinger, + # the pushforward result will be real (even if the seed is Complex) + + x = rand(Float64) + f, myabs2_pushforward = frule(myabs2, x) + @test f === x^2 + + Δ = One() + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === x + x + + Δ = rand(Complex{Int64}) + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === Δ * (x + x) + end + + @testset "complex input" begin + z = rand(Complex{Float64}) + f, myabs2_pushforward = frule(myabs2, z) + @test f === abs2(z) + + df = @inferred myabs2_pushforward(NamedTuple(), One()) + @test df === Wirtinger(z', z) + + Δ = rand(Complex{Int64}) + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === Wirtinger(Δ * z', Δ * z) + end +end + + +@testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin + # This is based on SimeonSchaub's excellent example: + # https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 + + # This is much more complex than the previous case + # as it has many different types + # depending on input, and the output types do not always agree + + abs_to_pow(x, p) = abs(x)^p + @scalar_rule( + abs_to_pow(x::Real, p), + ( + p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x), + Ω * log(abs(x)) + ) + ) + + @scalar_rule( + abs_to_pow(x::Complex, p), + @setup(u = abs(x)), + ( + p == 0 ? 
Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u), + Ω * log(abs(x)) + ) + ) + + + f = abs_to_pow + @testset "f($x, $p)" for (x, p) in Iterators.product( + (2, 3.4, -2.1, -10+0im, 2.3-2im), + (0, 1, 2, 4.3, -2.1, 1+.2im) + ) + expected_type_df_dx = + if iszero(p) + Zero + elseif typeof(x) <: Complex + Wirtinger + elseif typeof(p) <: Complex + Complex + else + Real + end + + expected_type_df_dp = + if typeof(p) <: Real + Real + else + Complex + end + + + res = frule(f, x, p) + @test res !== nothing # Check the rule was defined + fx, f_pushforward = res + df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) + + df_dx::Thunk = df(One(), Zero()) + df_dp::Thunk = df(Zero(), One()) + @test fx == f(x, p) # Check we still get the normal value, right + @test df_dx() isa expected_type_df_dx + @test df_dp() isa expected_type_df_dp + + + res = rrule(f, x, p) + @test res !== nothing # Check the rule was defined + fx, f_pullback = res + dself, df_dx, df_dp = f_pullback(One()) + @test fx == f(x, p) # Check we still get the normal value, right + @test dself == NO_FIELDS + @test df_dx() isa expected_type_df_dx + @test df_dp() isa expected_type_df_dp + end end diff --git a/test/runtests.jl b/test/runtests.jl index 4e10c970b..5f8552935 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,12 +4,10 @@ using ChainRulesCore using LinearAlgebra: Diagonal using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, - Zero, One, Casted, cast, - DNE, Thunk, Casted, DNERule, WirtingerRule + Zero, One, Casted, cast, DNE, Thunk using Base.Broadcast: broadcastable @testset "ChainRulesCore" begin include("differentials.jl") include("rules.jl") - include("rule_types.jl") end