From 0fabda9532a1a4e5724945275b18120112a70a7f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 28 Aug 2019 14:23:28 +0100 Subject: [PATCH 01/23] =WIP Derivative wrt function =Make frule wrt self and rrule wrt self different [WIP --- Project.toml | 2 +- src/ChainRulesCore.jl | 1 + src/rule_definition_tools.jl | 28 +++- src/rules.jl | 239 +++++++++++++++++++++++++++++++++++ 4 files changed, 263 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index a81fc0d4d..f7686dbe4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.2.1-DEV" +version = "v0.3.0" [compat] julia = "^1.0" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 2fd3302c9..f5baddab9 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,6 +4,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad export AbstractRule, Rule, frule, rrule export @scalar_rule, @thunk export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule +export NO_FIELDS_RULE, ZERO_RULE include("differentials.jl") include("differential_arithmetic.jl") diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 1caa0d987..af233406f 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -14,7 +14,8 @@ methods for `frule` and `rrule`: function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) - return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...), + return Ω, (ZERO_RULE, + Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...), Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...), ...) end @@ -22,7 +23,8 @@ methods for `frule` and `rrule`: function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) - return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), + return Ω, (NO_FIELDS_RULE, + Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...), ...) end @@ -34,11 +36,16 @@ Constraints may also be explicitly be provided to override the `Number` constrai e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to `Number`. -Note that the result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This +At present this does not support defining rules for closures/functors. +This the first returned rule, representing the derivative with respect to the +function itself, is always the `NO_FIELDS_RULE` (reverse-mode), +or `ZERO_RULE` (forward-mode). + +The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This allows the primal result to be conveniently referenced (as `Ω`) within the derivative/setup expressions. -Note that the `@setup` argument can be elided if no setup code is need. In other +The `@setup` argument can be elided if no setup code is need. In other words: @scalar_rule(f(x₁, x₂, ...), @@ -92,9 +99,18 @@ macro scalar_rule(call, maybe_setup, partials...) forward_rules = Any[rule_from_partials(inputs[1], partial) for partial in partials] reverse_rules = Any[rule_from_partials(inputs[1], partials...)] end - forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...) - reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...) + + # First pseudo-partial is derivative WRT function itself. Since this macro does not + # support closures, it is just the empty NamedTuple + forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...) + reverse_rules = Expr(:tuple, NO_FIELDS_RULE, reverse_rules...) return quote + if fieldcount(typeof($f)) > 0 + throw(ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $f)" + )) + end + function ChainRulesCore.frule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) diff --git a/src/rules.jl b/src/rules.jl index 6a145fad7..5537fa3f8 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,3 +1,242 @@ +""" +Subtypes of `AbstractRule` are types which represent the primitive derivative +propagation "rules" that can be composed to implement forward- and reverse-mode +automatic differentiation. + +More specifically, a `rule::AbstractRule` is a callable Julia object generally +obtained via calling [`frule`](@ref) or [`rrule`](@ref). Such rules accept +differential values as input, evaluate the chain rule using internally stored/ +computed partial derivatives to produce a single differential value, then +return that calculated differential value. + +For example: + +```jldoctest +julia> using ChainRulesCore: frule, rrule, AbstractRule + +julia> x, y = rand(2); + +julia> h, dh = frule(hypot, x, y); + +julia> h == hypot(x, y) +true + +julia> isa(dh, AbstractRule) +true + +julia> Δx, Δy = rand(2); + +julia> dh(Δx, Δy) == ((x / h) * Δx + (y / h) * Δy) +true + +julia> h, (dx, dy) = rrule(hypot, x, y); + +julia> h == hypot(x, y) +true + +julia> isa(dx, AbstractRule) && isa(dy, AbstractRule) +true + +julia> Δh = rand(); + +julia> dx(Δh) == (x / h) * Δh +true + +julia> dy(Δh) == (y / h) * Δh +true +``` + +See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [`WirtingerRule`](@ref) +""" +abstract type AbstractRule end + +# this ensures that consumers don't have to special-case rule destructuring +Base.iterate(rule::AbstractRule) = (rule, nothing) +Base.iterate(::AbstractRule, ::Any) = nothing + +# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple +# in order to get the `i`th rule (assuming it's 1) +Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError()) + +""" + accumulate(Δ, rule::AbstractRule, args...) + +Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore' +various `AbstractDifferential` types. + +This method intended to be customizable for specific rules/input types. For +example, here is pseudocode to overload `accumulate` w.r.t. a specific forward +differentiation rule for a given function `f`: + +``` +df(x) = # forward differentiation primitive implementation + +frule(::typeof(f), x) = (f(x), Rule(df)) + +accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation +``` + +See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +""" +accumulate(Δ, rule::AbstractRule, args...) = add(Δ, rule(args...)) + +""" + accumulate!(Δ, rule::AbstractRule, args...) + +Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place, +storing the result in `Δ`. + +Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. + +See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +""" +function accumulate!(Δ, rule::AbstractRule, args...) + return materialize!(Δ, broadcastable(add(cast(Δ), rule(args...)))) +end + +accumulate!(Δ::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...) + +""" + store!(Δ, rule::AbstractRule, args...) + +Compute `rule(args...)` and store the result in `Δ`, potentially avoiding +intermediate temporary allocations that might be necessary for alternative +approaches (e.g. `copyto!(Δ, extern(rule(args...)))`) + +Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. + +Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended +to be customizable for specific rules/input types. + +See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) +""" +store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...))) + +##### +##### `Rule` +##### + +Cassette.@context RuleContext + +const RULE_CONTEXT = Cassette.disablehooks(RuleContext()) + +Cassette.overdub(::RuleContext, ::typeof(+), a, b) = add(a, b) +Cassette.overdub(::RuleContext, ::typeof(*), a, b) = mul(a, b) + +Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b) +Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b) + +""" + Rule(propation_function[, updating_function]) + +Return a `Rule` that wraps the given `propation_function`. It is assumed that +`propation_function` is a callable object whose arguments are differential +values, and whose output is a single differential value calculated by applying +internally stored/computed partial derivatives to the input differential +values. + +If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)` +and to store the result of the propagation function applied to the arguments `xs` into +`Δ` in-place, returning `Δ`. + +For example: + +``` +frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy) + +rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) +``` + +See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref) +""" +struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule + f::F + u::U +end + +# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some +# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}` +Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing) + +(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...) + +# Specialized accumulation +# TODO: Does this need to be overdubbed in the rule context? +accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) + + +""" + NO_FIELDS_RULE + +Constant for the rule for the derivative with respect to structure that has no fields. +The most notable use for this is for the reverse-mode derivative with respect to the +function itself, when that function is not a closure. +The rule returns an empty `NamedTuple` for all inputs. +""" +const NO_FIELDS_RULE = Rule((args...)->NamedTuple()) + +""" + ZERO_RULE + +This is a rule that returns `Zero()` regardless of input. +The most notable use for this is for the forward-mode derivative with respect to the +function itself, when that function is not a closure. +""" +const ZERO_RULE = Rule((args...)->Zero()) + + + +##### +##### `DNERule` +##### + +""" + DNERule(args...) + +Construct a `DNERule` object, which is an `AbstractRule` that signifies that the +current function is not differentiable with respect to a particular parameter. +**DNE** is an abbreviation for Does Not Exist. +""" +struct DNERule <: AbstractRule end + +DNERule(args...) = DNE() + +##### +##### `WirtingerRule` +##### + +""" + WirtingerRule(primal::AbstractRule, conjugate::AbstractRule) + +Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of +an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate +derivative ``∂/∂x̅``. If the domain `𝒟` of the function might be real, consider +calling `AbstractRule(𝒟, primal, conjugate)` instead, to make use of a more +efficient representation wherever possible. +""" +struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule + primal::P + conjugate::C +end + +function (rule::WirtingerRule)(args...) + return Wirtinger(rule.primal(args...), rule.conjugate(args...)) +end + +""" + AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) + +Return a `Rule` evaluating to `primal(Δ) + conjugate(Δ)` if `𝒟 <: Real`, +otherwise return `WirtingerRule(P, C)`. +""" +function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) + if 𝒟 <: Real || eltype(𝒟) <: Real + return Rule((args...) -> add(primal(args...), conjugate(args...))) + else + return WirtingerRule(primal, conjugate) + end +end + ##### ##### `frule`/`rrule` ##### From 7a72d84ad13c034b8b5b82e423557a0fd8cc289a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 28 Aug 2019 18:32:46 +0100 Subject: [PATCH 02/23] make rules render better --- src/ChainRulesCore.jl | 3 ++- src/rules.jl | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f5baddab9..768a0d274 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -3,7 +3,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad export AbstractRule, Rule, frule, rrule export @scalar_rule, @thunk -export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule +export extern, cast, store! +export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule export NO_FIELDS_RULE, ZERO_RULE include("differentials.jl") diff --git a/src/rules.jl b/src/rules.jl index 5537fa3f8..fbd6c8d07 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -160,6 +160,9 @@ Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing) (rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...) +Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))") +Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))") + # Specialized accumulation # TODO: Does this need to be overdubbed in the rule context? accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) @@ -173,7 +176,7 @@ The most notable use for this is for the reverse-mode derivative with respect to function itself, when that function is not a closure. The rule returns an empty `NamedTuple` for all inputs. """ -const NO_FIELDS_RULE = Rule((args...)->NamedTuple()) +const NO_FIELDS_RULE = Rule(function no_fields(args...) NamedTuple() end) """ ZERO_RULE @@ -182,7 +185,7 @@ This is a rule that returns `Zero()` regardless of input. The most notable use for this is for the forward-mode derivative with respect to the function itself, when that function is not a closure. """ -const ZERO_RULE = Rule((args...)->Zero()) +const ZERO_RULE = Rule(function always_zero(args...) Zero() end) From 7789a67013b1318f7617a424bb53f9394d63d249 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 30 Aug 2019 16:49:29 +0100 Subject: [PATCH 03/23] fix version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f7686dbe4..938852959 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "v0.3.0" +version = "0.3.0" [compat] julia = "^1.0" From 0b232a4ecce705414c56dc1f11029165041b3992 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 30 Aug 2019 19:55:23 +0100 Subject: [PATCH 04/23] Remove duplicate of rule_types from bbad rebase WIP only one pullback with many partials (re #38) --- src/ChainRulesCore.jl | 4 +- src/differentials.jl | 8 ++ src/rule_definition_tools.jl | 85 +++++++++--- src/rule_types.jl | 15 ++- src/rules.jl | 242 ----------------------------------- test/rule_types.jl | 4 +- test/runtests.jl | 4 +- 7 files changed, 93 insertions(+), 269 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 768a0d274..5ede7da7d 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -5,11 +5,11 @@ export AbstractRule, Rule, frule, rrule export @scalar_rule, @thunk export extern, cast, store! export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule -export NO_FIELDS_RULE, ZERO_RULE +export NO_FIELDS include("differentials.jl") include("differential_arithmetic.jl") include("rule_types.jl") include("rules.jl") -include("rule_definition_tools.jl") +#include("rule_definition_tools.jl") end # module diff --git a/src/differentials.jl b/src/differentials.jl index 3f3e19d62..9f74e2c0a 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -227,3 +227,11 @@ end Base.conj(x::Thunk) = @thunk(conj(extern(x))) Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") + +""" + NO_FIELDS +Constant for the reverse-mode derivative with respect to a structure that has no fields. +The most notable use for this is for the reverse-mode derivative with respect to the +function itself, when that function is not a closure. +""" +const NO_FIELDS = DNE() diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index af233406f..d68eb0410 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -14,19 +14,22 @@ methods for `frule` and `rrule`: function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) - return Ω, (ZERO_RULE, - Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...), - Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...), - ...) + return Ω, (_, Δx₁, Δx₂, ...) -> ( + (∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...), + (∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...), + ... + ) end function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...) Ω = f(x₁, x₂, ...) \$(statement₁, statement₂, ...) - return Ω, (NO_FIELDS_RULE, - Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), - Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...), - ...) + return Ω, (ΔΩ₁, ΔΩ₂, ...) -> ( + NO_FIELDS, + ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...), + ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...), + ... + ) end If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are @@ -36,10 +39,10 @@ Constraints may also be explicitly be provided to override the `Number` constrai e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to `Number`. -At present this does not support defining rules for closures/functors. -This the first returned rule, representing the derivative with respect to the -function itself, is always the `NO_FIELDS_RULE` (reverse-mode), -or `ZERO_RULE` (forward-mode). +At present this does not support defining for closures/functors. +Thus in reverse-mode, the first returned partial, +representing the derivative with respect to the function itself, is always `NO_FIELDS`. +And in forwards-mode, the first input to the returned propergator is always ignored. The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This allows the primal result to be conveniently referenced (as `Ω`) within the @@ -86,11 +89,54 @@ macro scalar_rule(call, maybe_setup, partials...) call.args[i] = esc(arg) end end - if all(Meta.isexpr(partial, :tuple) for partial in partials) + + partials = map(partials) do partial + if Meta.isexpr(partial, :tuple) + partial + else + @assert length(inputs) == 1 + Expr(:tuple, partial) + end + end + @show partials + + ############################################################ + # Make pullback + #(TODO: move to own function) + # TODO: Wirtinger + + Δs = [Symbol(string(:Δ, i)) for i in 1:length(partials)] + pullback_returns = map(eachindex(inputs)) do input_i + ∂s = [partials.args[input_i] for partial in partials] + ∂s = map(esc, ∂s) + + # Notice: the thunking of `∂s[i] (potentially) saves us some computation + # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon + # as the pullback is evaluated + ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] + :(+($(∂_mul_Δs...))) + else + + pullback = quote + function $(Symbol(nameof(f), :_pullback))($(Δs...)) + return (ChainRulesCore.NO_FIELDS, $(pullback_returns...)) + end + end + + ######################################## + quote + function ChainRulesCore.rrule(::typeof($f), $(inputs...)) + $(esc(:Ω)) = $call + $(setup_stmts...) + return $(esc(:Ω)), $esc(pullback) + end + end +end +#== + if !all(Meta.isexpr(partial, :tuple) for partial in partials) input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials] - reverse_rules = Any[] - for i in 1:length(inputs) + reverse_rules = map(1:length(inputs) do i reverse_partials = [partial.args[i] for partial in partials] push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...)) end @@ -103,7 +149,7 @@ macro scalar_rule(call, maybe_setup, partials...) # First pseudo-partial is derivative WRT function itself. Since this macro does not # support closures, it is just the empty NamedTuple forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...) - reverse_rules = Expr(:tuple, NO_FIELDS_RULE, reverse_rules...) + reverse_rules = Expr(:tuple, NO_FIELDS, reverse_rules...) return quote if fieldcount(typeof($f)) > 0 throw(ArgumentError( @@ -123,7 +169,13 @@ macro scalar_rule(call, maybe_setup, partials...) end end end +==# + +@macroexpand(@scalar_rule(one(x), Zero())) + + +#== function rule_from_partials(input_arg, ∂s...) wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s) ∂s = map(esc, ∂s) @@ -161,3 +213,4 @@ function rule_from_partials(input_arg, ∂s...) end end end +==# diff --git a/src/rule_types.jl b/src/rule_types.jl index 158838883..5364f3751 100644 --- a/src/rule_types.jl +++ b/src/rule_types.jl @@ -51,12 +51,15 @@ See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [ abstract type AbstractRule end # this ensures that consumers don't have to special-case rule destructuring -Base.iterate(rule::AbstractRule) = (rule, nothing) +Base.iterate(rule::AbstractRule) = (@warn "iterating rules is going away"; (rule, nothing)) Base.iterate(::AbstractRule, ::Any) = nothing # This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple # in order to get the `i`th rule (assuming it's 1) -Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError()) +function Base.getindex(rule::AbstractRule, i::Integer) + @warn "iterating rules is going away" + return i == 1 ? rule : throw(BoundsError()) +end """ accumulate(Δ, rule::AbstractRule, args...) @@ -78,7 +81,7 @@ accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementa See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) """ -accumulate(Δ, rule::AbstractRule, args...) = Δ + rule(args...) +accumulate(Δ, rule, args...) = Δ + rule(args...) """ accumulate!(Δ, rule::AbstractRule, args...) @@ -90,11 +93,11 @@ Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) """ -function accumulate!(Δ, rule::AbstractRule, args...) +function accumulate!(Δ, rule, args...) return materialize!(Δ, broadcastable(cast(Δ) + rule(args...))) end -accumulate!(Δ::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...) +accumulate!(Δ::Number, rule, args...) = accumulate(Δ, rule, args...) """ store!(Δ, rule::AbstractRule, args...) @@ -110,7 +113,7 @@ to be customizable for specific rules/input types. See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) """ -store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...))) +store!(Δ, rule, args...) = materialize!(Δ, broadcastable(rule(args...))) ##### ##### `Rule` diff --git a/src/rules.jl b/src/rules.jl index fbd6c8d07..6a145fad7 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -1,245 +1,3 @@ -""" -Subtypes of `AbstractRule` are types which represent the primitive derivative -propagation "rules" that can be composed to implement forward- and reverse-mode -automatic differentiation. - -More specifically, a `rule::AbstractRule` is a callable Julia object generally -obtained via calling [`frule`](@ref) or [`rrule`](@ref). Such rules accept -differential values as input, evaluate the chain rule using internally stored/ -computed partial derivatives to produce a single differential value, then -return that calculated differential value. - -For example: - -```jldoctest -julia> using ChainRulesCore: frule, rrule, AbstractRule - -julia> x, y = rand(2); - -julia> h, dh = frule(hypot, x, y); - -julia> h == hypot(x, y) -true - -julia> isa(dh, AbstractRule) -true - -julia> Δx, Δy = rand(2); - -julia> dh(Δx, Δy) == ((x / h) * Δx + (y / h) * Δy) -true - -julia> h, (dx, dy) = rrule(hypot, x, y); - -julia> h == hypot(x, y) -true - -julia> isa(dx, AbstractRule) && isa(dy, AbstractRule) -true - -julia> Δh = rand(); - -julia> dx(Δh) == (x / h) * Δh -true - -julia> dy(Δh) == (y / h) * Δh -true -``` - -See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [`WirtingerRule`](@ref) -""" -abstract type AbstractRule end - -# this ensures that consumers don't have to special-case rule destructuring -Base.iterate(rule::AbstractRule) = (rule, nothing) -Base.iterate(::AbstractRule, ::Any) = nothing - -# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple -# in order to get the `i`th rule (assuming it's 1) -Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError()) - -""" - accumulate(Δ, rule::AbstractRule, args...) - -Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore' -various `AbstractDifferential` types. - -This method intended to be customizable for specific rules/input types. For -example, here is pseudocode to overload `accumulate` w.r.t. a specific forward -differentiation rule for a given function `f`: - -``` -df(x) = # forward differentiation primitive implementation - -frule(::typeof(f), x) = (f(x), Rule(df)) - -accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation -``` - -See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -accumulate(Δ, rule::AbstractRule, args...) = add(Δ, rule(args...)) - -""" - accumulate!(Δ, rule::AbstractRule, args...) - -Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place, -storing the result in `Δ`. - -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -function accumulate!(Δ, rule::AbstractRule, args...) - return materialize!(Δ, broadcastable(add(cast(Δ), rule(args...)))) -end - -accumulate!(Δ::Number, rule::AbstractRule, args...) = accumulate(Δ, rule, args...) - -""" - store!(Δ, rule::AbstractRule, args...) - -Compute `rule(args...)` and store the result in `Δ`, potentially avoiding -intermediate temporary allocations that might be necessary for alternative -approaches (e.g. `copyto!(Δ, extern(rule(args...)))`) - -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended -to be customizable for specific rules/input types. - -See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) -""" -store!(Δ, rule::AbstractRule, args...) = materialize!(Δ, broadcastable(rule(args...))) - -##### -##### `Rule` -##### - -Cassette.@context RuleContext - -const RULE_CONTEXT = Cassette.disablehooks(RuleContext()) - -Cassette.overdub(::RuleContext, ::typeof(+), a, b) = add(a, b) -Cassette.overdub(::RuleContext, ::typeof(*), a, b) = mul(a, b) - -Cassette.overdub(::RuleContext, ::typeof(add), a, b) = add(a, b) -Cassette.overdub(::RuleContext, ::typeof(mul), a, b) = mul(a, b) - -""" - Rule(propation_function[, updating_function]) - -Return a `Rule` that wraps the given `propation_function`. It is assumed that -`propation_function` is a callable object whose arguments are differential -values, and whose output is a single differential value calculated by applying -internally stored/computed partial derivatives to the input differential -values. - -If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)` -and to store the result of the propagation function applied to the arguments `xs` into -`Δ` in-place, returning `Δ`. - -For example: - -``` -frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy) - -rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) -``` - -See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref) -""" -struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule - f::F - u::U -end - -# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some -# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}` -Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing) - -(rule::Rule{F})(args...) where {F} = Cassette.overdub(RULE_CONTEXT, rule.f, args...) - -Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))") -Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))") - -# Specialized accumulation -# TODO: Does this need to be overdubbed in the rule context? -accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) - - -""" - NO_FIELDS_RULE - -Constant for the rule for the derivative with respect to structure that has no fields. -The most notable use for this is for the reverse-mode derivative with respect to the -function itself, when that function is not a closure. -The rule returns an empty `NamedTuple` for all inputs. -""" -const NO_FIELDS_RULE = Rule(function no_fields(args...) NamedTuple() end) - -""" - ZERO_RULE - -This is a rule that returns `Zero()` regardless of input. -The most notable use for this is for the forward-mode derivative with respect to the -function itself, when that function is not a closure. -""" -const ZERO_RULE = Rule(function always_zero(args...) Zero() end) - - - -##### -##### `DNERule` -##### - -""" - DNERule(args...) - -Construct a `DNERule` object, which is an `AbstractRule` that signifies that the -current function is not differentiable with respect to a particular parameter. -**DNE** is an abbreviation for Does Not Exist. -""" -struct DNERule <: AbstractRule end - -DNERule(args...) = DNE() - -##### -##### `WirtingerRule` -##### - -""" - WirtingerRule(primal::AbstractRule, conjugate::AbstractRule) - -Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of -an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate -derivative ``∂/∂x̅``. If the domain `𝒟` of the function might be real, consider -calling `AbstractRule(𝒟, primal, conjugate)` instead, to make use of a more -efficient representation wherever possible. -""" -struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule - primal::P - conjugate::C -end - -function (rule::WirtingerRule)(args...) - return Wirtinger(rule.primal(args...), rule.conjugate(args...)) -end - -""" - AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) - -Return a `Rule` evaluating to `primal(Δ) + conjugate(Δ)` if `𝒟 <: Real`, -otherwise return `WirtingerRule(P, C)`. -""" -function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) - if 𝒟 <: Real || eltype(𝒟) <: Real - return Rule((args...) -> add(primal(args...), conjugate(args...))) - else - return WirtingerRule(primal, conjugate) - end -end - ##### ##### `frule`/`rrule` ##### diff --git a/test/rule_types.jl b/test/rule_types.jl index 79aa362cc..e7f00dfe3 100644 --- a/test/rule_types.jl +++ b/test/rule_types.jl @@ -1,5 +1,6 @@ @testset "rule types" begin + #== @testset "iterating and indexing rules" begin _, rule = frule(dummy_identity, 1) i = 0 @@ -11,7 +12,8 @@ @test rule[1] == rule @test_throws BoundsError rule[2] end - + ==# + @testset "Rule" begin @testset "show" begin @test occursin(r"^Rule\(.*foo.*\)$", repr(Rule(function foo() 1 end))) diff --git a/test/runtests.jl b/test/runtests.jl index 4e10c970b..98cda4045 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using Test using ChainRulesCore using LinearAlgebra: Diagonal -using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, +using ChainRulesCore: extern, accumulate, accumulate!, store!, # @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, Zero, One, Casted, cast, DNE, Thunk, Casted, DNERule, WirtingerRule @@ -10,6 +10,6 @@ using Base.Broadcast: broadcastable @testset "ChainRulesCore" begin include("differentials.jl") - include("rules.jl") + #include("rules.jl") include("rule_types.jl") end From de2bb62d26a9220070fb458c688c9a39e30240f0 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 3 Sep 2019 11:40:25 +0100 Subject: [PATCH 05/23] WIP work on new @scalar_rule make real scalar rules work. correct @scalarrule forward rule return Wirtinger scalar working work WirtingerRule test as a test of @scalar_rule Fix spelling Co-Authored-By: simeonschaub Oxford Comma Co-Authored-By: simeonschaub spelling Co-Authored-By: Nick Robinson docstring for propagator_name spelling Co-Authored-By: Nick Robinson --- src/ChainRulesCore.jl | 3 +- src/differentials.jl | 14 +++ src/rule_definition_tools.jl | 230 +++++++++++++++++++++-------------- test/differentials.jl | 16 +++ test/rule_types.jl | 8 +- test/rules.jl | 42 ++++++- test/runtests.jl | 4 +- 7 files changed, 213 insertions(+), 104 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 5ede7da7d..3bccfdae9 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,6 +2,7 @@ module ChainRulesCore using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable export AbstractRule, Rule, frule, rrule +export wirtinger_conjugate, wirtinger_primal, differential export @scalar_rule, @thunk export extern, cast, store! export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule @@ -11,5 +12,5 @@ include("differentials.jl") include("differential_arithmetic.jl") include("rule_types.jl") include("rules.jl") -#include("rule_definition_tools.jl") +include("rule_definition_tools.jl") end # module diff --git a/src/differentials.jl b/src/differentials.jl index 9f74e2c0a..0a0768757 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -230,8 +230,22 @@ Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") """ NO_FIELDS + Constant for the reverse-mode derivative with respect to a structure that has no fields. The most notable use for this is for the reverse-mode derivative with respect to the function itself, when that function is not a closure. """ const NO_FIELDS = DNE() + +#### +""" + differential(𝒟::Type, der) + +For some differential (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), +convert it to another differential that is more suited for the domain given by +the type 𝒟. +""" +function differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) + return wirtinger_primal(w) + wirtinger_conjugate(w) +end +differential(::Any, der) = der # most of the time leave it alone. diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d68eb0410..b79dd8ece 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,5 +1,89 @@ # These are some macros (and supporting functions) to make it easier to define rules. +""" + propagator_name(f, propname) +Determines a reasonable name for the propagator function. +The name doesn't really matter too much as it is a local function to be returned +by `frule` or `rrule`, but a good name make debugging easier. +`f` should be some form of AST representation of the actual function, +`propname` should be either `:pullback` or `:pushforward` + +This is able to deal with fairly complex expressions for `f`: + + julia> propagator_name(:bar, :pushforward) + :bar_pushforward + + julia> propagator_name(esc(:(Base.Random.foo)), :pullback) + :foo_pullback +""" +propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname) +propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) +propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) + + +""" + propagation_expr(𝒟, Δs, ∂s) + + Returns the expression for the propagation of + the input gradient `Δs` though the partials `∂s`. + + 𝒟 is an expression that when evaluated returns the type-of the input domain. + For example if the derivative is being taken at the point `1` it returns `Int`. + if it is taken at `1+1im` it returns `Complex{Int}`. + At present it is ignored for non-Wirtinger derivatives. +""" +function propagation_expr(𝒟, Δs, ∂s) + wirtinger_indices = findall(∂s) do ex + Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger + end + ∂s = map(esc, ∂s) + if isempty(wirtinger_indices) + return standard_propagation_expr(Δs, ∂s) + else + return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) + end +end + +function standard_propagation_expr(Δs, ∂s) + # This is basically Δs ⋅ ∂s + + # Notice: the thunking of `∂s[i] (potentially) saves us some computation + # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon + # as the pullback is evaluated + ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] + return :(+($(∂_mul_Δs...))) +end + +function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) + ∂_mul_Δs_primal = Any[] + ∂_mul_Δs_conjugate = Any[] + ∂_wirtinger_defs = Any[] + for i in 1:length(∂s) + if i in wirtinger_indices + Δi = Δs[i] + ∂i = Symbol(string(:∂, i)) + push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) + ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) + ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) + ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) + ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) + push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) + push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) + else + ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) + push!(∂_mul_Δs_primal, ∂_mul_Δ) + push!(∂_mul_Δs_conjugate, ∂_mul_Δ) + end + end + primal_sum = :(+($(∂_mul_Δs_primal...))) + conjugate_sum = :(+($(∂_mul_Δs_conjugate...))) + return quote # This will be a block, so will have value equal to last statement + $(∂_wirtinger_defs...) + w = Wirtinger($primal_sum, $conjugate_sum) + differential($𝒟, w) + end +end + """ @scalar_rule(f(x₁, x₂, ...), @setup(statement₁, statement₂, ...), @@ -42,7 +126,7 @@ e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x At present this does not support defining for closures/functors. Thus in reverse-mode, the first returned partial, representing the derivative with respect to the function itself, is always `NO_FIELDS`. -And in forwards-mode, the first input to the returned propergator is always ignored. +And in forward-mode, the first input to the returned propagator is always ignored. The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This allows the primal result to be conveniently referenced (as `Ω`) within the @@ -69,6 +153,9 @@ For examples, see ChainRulesCore' `rules` directory. See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref) """ macro scalar_rule(call, maybe_setup, partials...) + ############################################################################ + # Setup: normalizing input form etc + if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup") setup_stmts = map(esc, maybe_setup.args[3:end]) else @@ -77,6 +164,7 @@ macro scalar_rule(call, maybe_setup, partials...) end @assert Meta.isexpr(call, :call) f = esc(call.args[1]) + # Annotate all arguments in the signature as scalars inputs = map(call.args[2:end]) do arg esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number)) @@ -90,6 +178,7 @@ macro scalar_rule(call, maybe_setup, partials...) end end + # For consistency in code that follows we make all partials tuple expressions partials = map(partials) do partial if Meta.isexpr(partial, :tuple) partial @@ -98,59 +187,58 @@ macro scalar_rule(call, maybe_setup, partials...) Expr(:tuple, partial) end end - @show partials - - ############################################################ - # Make pullback - #(TODO: move to own function) - # TODO: Wirtinger - - Δs = [Symbol(string(:Δ, i)) for i in 1:length(partials)] - pullback_returns = map(eachindex(inputs)) do input_i - ∂s = [partials.args[input_i] for partial in partials] - ∂s = map(esc, ∂s) - - # Notice: the thunking of `∂s[i] (potentially) saves us some computation - # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon - # as the pullback is evaluated - ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] - :(+($(∂_mul_Δs...))) - else - pullback = quote - function $(Symbol(nameof(f), :_pullback))($(Δs...)) - return (ChainRulesCore.NO_FIELDS, $(pullback_returns...)) + ############################################################################ + # Main body: defining the results of the frule/rrule + + # An expression that when evaluated will return the type of the input domain. + # Multiple repetitions of this expression should optimize ot. But if it does not then + # may need to move its definition into the body of the `rrule`/`frule` + 𝒟 = :(typeof(first(promote($(call.args[2:end]...))))) + + n_outputs = length(partials) + n_inputs = length(inputs) + + pushforward = let + # Δs is the input to the propagator rule + # because this is push-forward there is one per input to the function + Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] + pushforward_returns = map(1:n_outputs) do output_i + ∂s = partials[output_i].args + propagation_expr(𝒟, Δs, ∂s) end - end - ######################################## - quote - function ChainRulesCore.rrule(::typeof($f), $(inputs...)) - $(esc(:Ω)) = $call - $(setup_stmts...) - return $(esc(:Ω)), $esc(pullback) + quote + # _ is the input derivative w.r.t. function internals. since we do not + # allow closures/functors with @scalar_rule, it is always ignored + function $(propagator_name(f, :pushforward))(_, $(Δs...)) + return $(Expr(:tuple, pushforward_returns...)) + end end end -end -#== - if !all(Meta.isexpr(partial, :tuple) for partial in partials) - input_rep = :(first(promote($(inputs...)))) # stand-in with the right type for an input - forward_rules = Any[rule_from_partials(input_rep, partial.args...) for partial in partials] - reverse_rules = map(1:length(inputs) do i - reverse_partials = [partial.args[i] for partial in partials] - push!(reverse_rules, rule_from_partials(inputs[i], reverse_partials...)) + + pullback = let + # Δs is the input to the propagator rule + # because this is a pull-back there is one per output of function + Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs] + + # 1 partial derivative per input + pullback_returns = map(1:n_inputs) do input_i + ∂s = [partial.args[input_i] for partial in partials] + propagation_expr(𝒟, Δs, ∂s) + end + + quote + function $(propagator_name(f, :pullback))($(Δs...)) + return (NO_FIELDS, $(pullback_returns...)) + end end - else - @assert length(inputs) == 1 && all(!Meta.isexpr(partial, :tuple) for partial in partials) - forward_rules = Any[rule_from_partials(inputs[1], partial) for partial in partials] - reverse_rules = Any[rule_from_partials(inputs[1], partials...)] end - # First pseudo-partial is derivative WRT function itself. Since this macro does not - # support closures, it is just the empty NamedTuple - forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...) - reverse_rules = Expr(:tuple, NO_FIELDS, reverse_rules...) - return quote + ############################################################################ + # Final return: building the expression to insert in the place of this macro + + code = quote if fieldcount(typeof($f)) > 0 throw(ArgumentError( "@scalar_rule cannot be used on closures/functors (such as $f)" @@ -160,57 +248,13 @@ end function ChainRulesCore.frule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) - return $(esc(:Ω)), $forward_rules + return $(esc(:Ω)), $pushforward end + function ChainRulesCore.rrule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) - return $(esc(:Ω)), $reverse_rules - end - end -end -==# - -@macroexpand(@scalar_rule(one(x), Zero())) - - - -#== -function rule_from_partials(input_arg, ∂s...) - wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s) - ∂s = map(esc, ∂s) - Δs = [Symbol(string(:Δ, i)) for i in 1:length(∂s)] - Δs_tuple = Expr(:tuple, Δs...) - if isempty(wirtinger_indices) - ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] - return :(Rule($Δs_tuple -> +($(∂_mul_Δs...)))) - else - ∂_mul_Δs_primal = Any[] - ∂_mul_Δs_conjugate = Any[] - ∂_wirtinger_defs = Any[] - for i in 1:length(∂s) - if i in wirtinger_indices - Δi = Δs[i] - ∂i = Symbol(string(:∂, i)) - push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) - ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) - ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) - ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) - ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) - push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) - push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) - else - ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) - push!(∂_mul_Δs_primal, ∂_mul_Δ) - push!(∂_mul_Δs_conjugate, ∂_mul_Δ) - end - end - primal_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_primal...)))) - conjugate_rule = :(Rule($Δs_tuple -> +($(∂_mul_Δs_conjugate...)))) - return quote - $(∂_wirtinger_defs...) - AbstractRule(typeof($input_arg), $primal_rule, $conjugate_rule) + return $(esc(:Ω)), $pullback end end end -==# diff --git a/test/differentials.jl b/test/differentials.jl index 557c125d8..73c76da83 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -79,4 +79,20 @@ ] @test isempty(ambig_methods) end + + + @testset "Differential" begin + @test differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2) + @test differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2) + + @test differential(typeof(1.2), Wirtinger(2,2)) == 4 + @test differential(typeof([1.2]), Wirtinger(2,2)) == 4 + + # For most differentials, in most domains, this does nothing + for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) + for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2])) + @test differential(𝒟, der) === der + end + end + end end diff --git a/test/rule_types.jl b/test/rule_types.jl index e7f00dfe3..581d7fd34 100644 --- a/test/rule_types.jl +++ b/test/rule_types.jl @@ -1,8 +1,8 @@ @testset "rule types" begin - #== + # The following is deprecated and should be remove next release @testset "iterating and indexing rules" begin - _, rule = frule(dummy_identity, 1) + rule = Rule(identity) i = 0 for r in rule @test r === rule @@ -12,8 +12,8 @@ @test rule[1] == rule @test_throws BoundsError rule[2] end - ==# - + + @testset "Rule" begin @testset "show" begin @test occursin(r"^Rule\(.*foo.*\)$", repr(Rule(function foo() 1 end))) diff --git a/test/rules.jl b/test/rules.jl index 7295da7a3..d7b5f19b2 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -6,7 +6,6 @@ cool(x, y) = x + y + 1 # a rule we define so we can test rules dummy_identity(x) = x - @scalar_rule(dummy_identity(x), One()) ####### @@ -19,6 +18,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrule(cool, 1) === nothing @test rrule(cool, 1; iscool=true) === nothing + # add some methods: ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) @test hasmethod(rrule, Tuple{typeof(cool),Number}) ChainRulesCore.@scalar_rule(Main.cool(x::String), "wow such dfdx") @@ -31,8 +31,42 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) frx, fr = frule(cool, 1) @test frx == 2 - @test fr(1) == 1 - rrx, rr = rrule(cool, 1) + @test fr(NamedTuple(), 1) == (1,) + rrx, (rr) = rrule(cool, 1) + self, rr1 = rr(1) + @test self == NO_FIELDS @test rrx == 2 - @test rr(1) == 1 + @test rr1 == 1 +end + + +@testset "Wirtinger scalar_rule" begin + myabs2(x) = abs2(x) + @scalar_rule(myabs2(x), Wirtinger(x', x)) + + # real input + x = rand(Float64) + f, pushforward = frule(myabs2, x) + @test f === x^2 + + df = @inferred pushforward(NamedTuple(), One()) + @test df === (x + x,) + + + Δ = rand(Complex{Int64}) + df = @inferred pushforward(NamedTuple(), Δ) + @test df === (Δ * (x + x),) + + + # complex input + z = rand(Complex{Float64}) + f, pushforward = frule(myabs2, z) + @test f === abs2(z) + + df = @inferred pushforward(NamedTuple(), One()) + @test df === (Wirtinger(z', z),) + + Δ = rand(Complex{Int64}) + df = @inferred pushforward(NamedTuple(), Δ) + @test df === (Wirtinger(Δ * z', Δ * z),) end diff --git a/test/runtests.jl b/test/runtests.jl index 98cda4045..4e10c970b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using Test using ChainRulesCore using LinearAlgebra: Diagonal -using ChainRulesCore: extern, accumulate, accumulate!, store!, # @scalar_rule, +using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, Zero, One, Casted, cast, DNE, Thunk, Casted, DNERule, WirtingerRule @@ -10,6 +10,6 @@ using Base.Broadcast: broadcastable @testset "ChainRulesCore" begin include("differentials.jl") - #include("rules.jl") + include("rules.jl") include("rule_types.jl") end From eb3c292b0c7a675b5ac8e939583e4f5ccc734511 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 3 Sep 2019 14:10:14 +0100 Subject: [PATCH 06/23] Update test/rules.jl error ratehr than Assert cleanup Update src/rule_definition_tools.jl Co-Authored-By: Nick Robinson Add more complex Wirtinger Scalar Rule Test --- src/differentials.jl | 7 +- src/rule_definition_tools.jl | 4 +- test/rules.jl | 125 ++++++++++++++++++++++++++++------- 3 files changed, 108 insertions(+), 28 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index 0a0768757..687917b74 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -241,9 +241,10 @@ const NO_FIELDS = DNE() """ differential(𝒟::Type, der) -For some differential (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), -convert it to another differential that is more suited for the domain given by -the type 𝒟. +Converts, if required, a differential object `der` +(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), +to another differential that is more suited for the domain given by the type 𝒟. +Often this will behave as the identity function on `der`. """ function differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) return wirtinger_primal(w) + wirtinger_conjugate(w) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index b79dd8ece..4475ac373 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -183,7 +183,7 @@ macro scalar_rule(call, maybe_setup, partials...) if Meta.isexpr(partial, :tuple) partial else - @assert length(inputs) == 1 + length(inputs) == 1 || error("Invalid use of `@scalar_rule`") Expr(:tuple, partial) end end @@ -192,7 +192,7 @@ macro scalar_rule(call, maybe_setup, partials...) # Main body: defining the results of the frule/rrule # An expression that when evaluated will return the type of the input domain. - # Multiple repetitions of this expression should optimize ot. But if it does not then + # Multiple repetitions of this expression should optimize out. But if it does not then # may need to move its definition into the body of the `rrule`/`frule` 𝒟 = :(typeof(first(promote($(call.args[2:end]...))))) diff --git a/test/rules.jl b/test/rules.jl index d7b5f19b2..e4a6e1c0c 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -29,44 +29,123 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) Tuple{typeof(rrule),typeof(cool),String}]) @test cool_methods == only_methods - frx, fr = frule(cool, 1) + frx, cool_pushforward = frule(cool, 1) @test frx == 2 - @test fr(NamedTuple(), 1) == (1,) - rrx, (rr) = rrule(cool, 1) - self, rr1 = rr(1) + @test cool_pushforward(NamedTuple(), 1) == (1,) + rrx, cool_pullback = rrule(cool, 1) + self, rr1 = cool_pullback(1) @test self == NO_FIELDS @test rrx == 2 @test rr1 == 1 end -@testset "Wirtinger scalar_rule" begin +@testset "Basic Wirtinger scalar_rule" begin myabs2(x) = abs2(x) @scalar_rule(myabs2(x), Wirtinger(x', x)) - # real input - x = rand(Float64) - f, pushforward = frule(myabs2, x) - @test f === x^2 + @testset "real input" begin + # even though our rule was define in terms of Wirtinger, + # pushforward result will be real as real (even if seed is Compex) - df = @inferred pushforward(NamedTuple(), One()) - @test df === (x + x,) + x = rand(Float64) + f, myabs2_pushforward = frule(myabs2, x) + @test f === x^2 + Δ = One() + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === (x + x,) - Δ = rand(Complex{Int64}) - df = @inferred pushforward(NamedTuple(), Δ) - @test df === (Δ * (x + x),) + Δ = rand(Complex{Int64}) + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === (Δ * (x + x),) + end + @testset "complex input" begin + z = rand(Complex{Float64}) + f, myabs2_pushforward = frule(myabs2, z) + @test f === abs2(z) - # complex input - z = rand(Complex{Float64}) - f, pushforward = frule(myabs2, z) - @test f === abs2(z) + df = @inferred myabs2_pushforward(NamedTuple(), One()) + @test df === (Wirtinger(z', z),) + + Δ = rand(Complex{Int64}) + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === (Wirtinger(Δ * z', Δ * z),) + end +end - df = @inferred pushforward(NamedTuple(), One()) - @test df === (Wirtinger(z', z),) - Δ = rand(Complex{Int64}) - df = @inferred pushforward(NamedTuple(), Δ) - @test df === (Wirtinger(Δ * z', Δ * z),) +@testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin + # This is based on SimeonSchaub excellent example: + # https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 + + # This is much more complex than the previous case + # as it has many different types + # depending on input, and the output types do not always agree + + abs_to_pow(x, p) = abs(x)^p + @scalar_rule( + abs_to_pow(x::Real, p), + ( + p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x), + Ω * log(abs(x)) + ) + ) + + @scalar_rule( + abs_to_pow(x::Complex, p), + @setup(u = abs(x)), + ( + p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u), + Ω * log(abs(x)) + ) + ) + + + f = abs_to_pow + @testset "f($x, $p)" for (x, p) in Iterators.product( + (2, 3.4, -2.1, -10+0im, 2.3-2im), + (0, 1, 2, 4.3, -2.1, 1+.2im) + ) + expected_type_df_dx = + if iszero(p) + Zero + elseif typeof(x) <: Complex + Wirtinger + elseif typeof(p) <: Complex + Complex + else + Real + end + + expected_type_df_dp = + if typeof(p) <: Real + Real + else + Complex + end + + + res = frule(f, x, p) + @test res !== nothing # Check the rule was defined + fx, f_pushforward = res + df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) + + df_dx, = df(One(), Zero()) + df_dp,= df(Zero(), One()) + @test fx == f(x, p) # Check we still get the normal value, right + @test extern(df_dx) isa expected_type_df_dx + @test extern(df_dp) isa expected_type_df_dp + + + res = rrule(f, x, p) + @test res !== nothing # Check the rule was defined + fx, f_pullback = res + dself, df_dx, df_dp = f_pullback(One()) + @test fx == f(x, p) # Check we still get the normal value, right + @test dself == NO_FIELDS + @test extern(df_dx) isa expected_type_df_dx + @test extern(df_dp) isa expected_type_df_dp + end end From f6979ac675f61c274ada6978ab453795878c1d06 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 3 Sep 2019 20:00:10 +0100 Subject: [PATCH 07/23] Make accumulate apply to Differentials update accumulate to work on differentials --- src/rule_types.jl | 121 +++++++++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 56 deletions(-) diff --git a/src/rule_types.jl b/src/rule_types.jl index 5364f3751..fe14fbe9c 100644 --- a/src/rule_types.jl +++ b/src/rule_types.jl @@ -61,59 +61,6 @@ function Base.getindex(rule::AbstractRule, i::Integer) return i == 1 ? rule : throw(BoundsError()) end -""" - accumulate(Δ, rule::AbstractRule, args...) - -Return `Δ + rule(args...)` evaluated in a manner that supports ChainRulesCore' -various `AbstractDifferential` types. - -This method intended to be customizable for specific rules/input types. For -example, here is pseudocode to overload `accumulate` w.r.t. a specific forward -differentiation rule for a given function `f`: - -``` -df(x) = # forward differentiation primitive implementation - -frule(::typeof(f), x) = (f(x), Rule(df)) - -accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation -``` - -See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -accumulate(Δ, rule, args...) = Δ + rule(args...) - -""" - accumulate!(Δ, rule::AbstractRule, args...) - -Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place, -storing the result in `Δ`. - -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -function accumulate!(Δ, rule, args...) - return materialize!(Δ, broadcastable(cast(Δ) + rule(args...))) -end - -accumulate!(Δ::Number, rule, args...) = accumulate(Δ, rule, args...) - -""" - store!(Δ, rule::AbstractRule, args...) - -Compute `rule(args...)` and store the result in `Δ`, potentially avoiding -intermediate temporary allocations that might be necessary for alternative -approaches (e.g. `copyto!(Δ, extern(rule(args...)))`) - -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended -to be customizable for specific rules/input types. - -See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) -""" -store!(Δ, rule, args...) = materialize!(Δ, broadcastable(rule(args...))) ##### ##### `Rule` @@ -157,9 +104,6 @@ Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing) Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))") Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))") -# Specialized accumulation -# TODO: Does this need to be overdubbed in the rule context? -accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) ##### ##### `DNERule` @@ -211,3 +155,68 @@ function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) return WirtingerRule(primal, conjugate) end end + + +""" + accumulate(Δ, ∂) + +Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore' +various `AbstractDifferential` types. + +#TODO: update these docs + +This method intended to be customizable for specific rules/input types. For +example, here is pseudocode to overload `accumulate` w.r.t. a specific forward +differentiation rule for a given function `f`: + +``` +df(x) = # forward differentiation primitive implementation + +frule(::typeof(f), x) = (f(x), Rule(df)) + +accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation +``` + +See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +""" +accumulate(Δ, ∂) = Δ + ∂ + +""" + accumulate!(Δ, rule::AbstractRule, args...) + +# TODO: Update these docs + +Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place, +storing the result in `Δ`. + +Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. + +See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +""" +function accumulate!(Δ, ∂) + return materialize!(Δ, broadcastable(cast(Δ) + ∂)) +end + +accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂) + +# TODO: replace this: +# accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) + + +""" + store!(Δ, ∂) + +TODO: Rewrite these docs + +Compute `rule(args...)` and store the result in `Δ`, potentially avoiding +intermediate temporary allocations that might be necessary for alternative +approaches (e.g. `copyto!(Δ, extern(rule(args...)))`) + +Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. + +Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended +to be customizable for specific rules/input types. + +See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) +""" +store!(Δ, ∂) = materialize!(Δ, broadcastable(∂)) From 55dcefe0f09ce94ae158c945e6008cc1a074a253 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 3 Sep 2019 21:55:50 +0100 Subject: [PATCH 08/23] Update src/rule_definition_tools.jl Co-Authored-By: Curtis Vogt --- src/rule_definition_tools.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 4475ac373..b1ba08113 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -2,6 +2,7 @@ """ propagator_name(f, propname) + Determines a reasonable name for the propagator function. The name doesn't really matter too much as it is a local function to be returned by `frule` or `rrule`, but a good name make debugging easier. From 53d5f9ed61b83ffe7d47181ad2556defe1d6a33a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 4 Sep 2019 17:33:44 +0100 Subject: [PATCH 09/23] Make frule return a scalar --- src/rule_definition_tools.jl | 9 +++++++-- test/rules.jl | 14 +++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index b1ba08113..5724412c5 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -208,12 +208,17 @@ macro scalar_rule(call, maybe_setup, partials...) ∂s = partials[output_i].args propagation_expr(𝒟, Δs, ∂s) end - + if n_outputs > 1 + # For forward-mode we only return a tuple if output actually a tuple. + pushforward_returns = Expr(:tuple, pushforward_returns...) + else + pushforward_returns = pushforward_returns[1] + end quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored function $(propagator_name(f, :pushforward))(_, $(Δs...)) - return $(Expr(:tuple, pushforward_returns...)) + $pushforward_returns end end end diff --git a/test/rules.jl b/test/rules.jl index e4a6e1c0c..59febb8a4 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -31,7 +31,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) frx, cool_pushforward = frule(cool, 1) @test frx == 2 - @test cool_pushforward(NamedTuple(), 1) == (1,) + @test cool_pushforward(NamedTuple(), 1) == 1 rrx, cool_pullback = rrule(cool, 1) self, rr1 = cool_pullback(1) @test self == NO_FIELDS @@ -54,11 +54,11 @@ end Δ = One() df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === (x + x,) + @test df === x + x Δ = rand(Complex{Int64}) df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === (Δ * (x + x),) + @test df === Δ * (x + x) end @testset "complex input" begin @@ -67,11 +67,11 @@ end @test f === abs2(z) df = @inferred myabs2_pushforward(NamedTuple(), One()) - @test df === (Wirtinger(z', z),) + @test df === Wirtinger(z', z) Δ = rand(Complex{Int64}) df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === (Wirtinger(Δ * z', Δ * z),) + @test df === Wirtinger(Δ * z', Δ * z) end end @@ -132,8 +132,8 @@ end fx, f_pushforward = res df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) - df_dx, = df(One(), Zero()) - df_dp,= df(Zero(), One()) + df_dx = df(One(), Zero()) + df_dp = df(Zero(), One()) @test fx == f(x, p) # Check we still get the normal value, right @test extern(df_dx) isa expected_type_df_dx @test extern(df_dp) isa expected_type_df_dp From 5b6c0d836ac1c0c6b2f6fb1b8fdce966f06ecce6 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 15:50:36 +0100 Subject: [PATCH 10/23] add docs about when to thunk --- src/differentials.jl | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/differentials.jl b/src/differentials.jl index 687917b74..f85f7665b 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -181,8 +181,10 @@ Base.iterate(::One, ::Any) = nothing Thunk(()->v) A thunk is a deferred computation. It wraps a zero argument closure that when invoked returns a differential. +`@thunk(v)` is a macro that expands into `Thunk(()->v)`. -Calling that thunk, calls the wrapped closure. + +Calling a thunk, calls the wrapped closure. `extern`ing thunks applies recursively, it also externs the differial that the closure returns. If you do not want that, then simply call the thunk @@ -199,6 +201,22 @@ Thunk(var"##8#10"()) julia> t()() 3 ``` + +### When to `@thunk`? +When writing `rrule`s (and to a lesser exent `frule`s), it is important to `@thunk` +appropriately. Propagation rule's that return multiple deriviatives are able to not +do all the work of computing all of them only to have just one used. +By `@thunk`ing the work required for each, they can only be computed when needed. + +#### So why not thunk everything? +`@thunk` creates a closure over the expression, which is basically a struct with a +field for each variable used in the expression (closed over), and call overloaded. +If this would be equal or more work than actually evaluating the expression then don't do +it. An example would be if the expression itself is just wrapping something in a struct. +Such as `Adjoint(x)` or `Diagonal(x)`. Or if the expression is a constant, or is +itself a `Thunk`. +If you got the expression from another `rrule` (or `frule`), you don't need to +`@thunk` it since it will have been thunked if required, by the defining rule. """ struct Thunk{F} <: AbstractDifferential f::F From 5c3fdaa4e646412c262ac52ed65fc409754e139b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 6 Sep 2019 17:26:11 +0100 Subject: [PATCH 11/23] fix up for recursive extern --- test/rules.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/rules.jl b/test/rules.jl index 59febb8a4..e23680326 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -132,11 +132,11 @@ end fx, f_pushforward = res df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) - df_dx = df(One(), Zero()) - df_dp = df(Zero(), One()) + df_dx::Thunk = df(One(), Zero()) + df_dp::Thunk = df(Zero(), One()) @test fx == f(x, p) # Check we still get the normal value, right - @test extern(df_dx) isa expected_type_df_dx - @test extern(df_dp) isa expected_type_df_dp + @test df_dx() isa expected_type_df_dx + @test df_dp() isa expected_type_df_dp res = rrule(f, x, p) @@ -145,7 +145,7 @@ end dself, df_dx, df_dp = f_pullback(One()) @test fx == f(x, p) # Check we still get the normal value, right @test dself == NO_FIELDS - @test extern(df_dx) isa expected_type_df_dx - @test extern(df_dp) isa expected_type_df_dp + @test df_dx() isa expected_type_df_dx + @test df_dp() isa expected_type_df_dp end end From f14e045136626beba9eb9847ae87fd869997eb22 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 10 Sep 2019 15:05:05 +0100 Subject: [PATCH 12/23] correct accumulate to be broadcasting --- src/rule_types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_types.jl b/src/rule_types.jl index fe14fbe9c..305b18d9f 100644 --- a/src/rule_types.jl +++ b/src/rule_types.jl @@ -179,7 +179,7 @@ accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementa See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) """ -accumulate(Δ, ∂) = Δ + ∂ +accumulate(Δ, ∂) = Δ .+ ∂ """ accumulate!(Δ, rule::AbstractRule, args...) From ef748c9b8f5d2c44b8ec70ea69d17999098d80a5 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 11 Sep 2019 14:39:26 +0100 Subject: [PATCH 13/23] Get rid of all Rule types, add InplacableThunk spelling is hard --- src/ChainRulesCore.jl | 6 +- src/differential_arithmetic.jl | 25 ++-- src/differentials.jl | 38 +++++- src/operations.jl | 45 +++++++ src/rule_types.jl | 222 --------------------------------- test/rule_types.jl | 63 ---------- test/runtests.jl | 4 +- 7 files changed, 104 insertions(+), 299 deletions(-) create mode 100644 src/operations.jl delete mode 100644 src/rule_types.jl delete mode 100644 test/rule_types.jl diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 3bccfdae9..275029cb8 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -1,16 +1,16 @@ module ChainRulesCore using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable -export AbstractRule, Rule, frule, rrule +export frule, rrule export wirtinger_conjugate, wirtinger_primal, differential export @scalar_rule, @thunk export extern, cast, store! -export Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule +export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk export NO_FIELDS include("differentials.jl") include("differential_arithmetic.jl") -include("rule_types.jl") +include("operations.jl") include("rules.jl") include("rule_definition_tools.jl") end # module diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index d7a78777c..ad649797e 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered. Thus we can avoid any ambiguities. Notice: - The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :Any) + The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any) Thus each of the @eval loops creating definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger) return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) end -for T in (:Casted, :Zero, :DNE, :One, :Thunk, :Any) +for T in (:Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any) @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b @@ -47,7 +47,7 @@ end Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value)) Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value)) -for T in (:Zero, :DNE, :One, :Thunk, :Any) +for T in (:Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any) @eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b)) @eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value)) @@ -58,7 +58,7 @@ end Base.:+(::Zero, b::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() -for T in (:DNE, :One, :Thunk, :Any) +for T in (:DNE, :One, :Thunk, :InplaceableThunk, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a @@ -69,7 +69,7 @@ end Base.:+(::DNE, ::DNE) = DNE() Base.:*(::DNE, ::DNE) = DNE() -for T in (:One, :Thunk, :Any) +for T in (:One, :Thunk, :InplaceableThunk, :Any) @eval Base.:+(::DNE, b::$T) = b @eval Base.:+(a::$T, ::DNE) = a @@ -80,7 +80,7 @@ end Base.:+(a::One, b::One) = extern(a) + extern(b) Base.:*(::One, ::One) = One() -for T in (:Thunk, :Any) +for T in (:Thunk, :InplaceableThunk, :Any) @eval Base.:+(a::One, b::$T) = extern(a) + b @eval Base.:+(a::$T, b::One) = a + extern(b) @@ -91,10 +91,21 @@ end Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b) Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b) -for T in (:Any,) #This loop is redundant but for consistency... +for T in (:InplaceableThunk, :Any) @eval Base.:+(a::Thunk, b::$T) = extern(a) + b @eval Base.:+(a::$T, b::Thunk) = a + extern(b) @eval Base.:*(a::Thunk, b::$T) = extern(a) * b @eval Base.:*(a::$T, b::Thunk) = a * extern(b) end + +# InplaceableThunk acts just like Thunk +Base.:+(a::InplaceableThunk, b::InplaceableThunk) = extern(a) + extern(b) +Base.:*(a::InplaceableThunk, b::InplaceableThunk) = extern(a) * extern(b) +for T in (:Any, ) + @eval Base.:+(a::InplaceableThunk, b::$T) = extern(a) + b + @eval Base.:+(a::$T, b::InplaceableThunk) = a + extern(b) + + @eval Base.:*(a::InplaceableThunk, b::$T) = extern(a) * b + @eval Base.:*(a::$T, b::InplaceableThunk) = a * extern(b) +end diff --git a/src/differentials.jl b/src/differentials.jl index f85f7665b..7fc7e710c 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -246,6 +246,43 @@ Base.conj(x::Thunk) = @thunk(conj(extern(x))) Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") +""" + InplaceableThunk(val::Thunk, add!::Function) + +A wrapper for a `Thunk`, that allows it to define an inplace `add!` function, +which is used internally in `accumulate!(Δ, ::InplaceableThunk)`. + +`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val` +but it should do this more efficently than simply doing this directly. +(Otherwise one can just use a normal `Thunk`). + +Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; +and destroy its inplacability. +""" +struct InplaceableThunk{T<:Thunk, F} <: AbstractDifferential + val::T + add!::F +end + +(x::InplaceableThunk)() = x.val() +@inline extern(x::InplaceableThunk) = extern(x.val) + +Base.Broadcast.broadcastable(x::InplaceableThunk) = broadcastable(x.val) + +@inline function Base.iterate(x::InplaceableThunk, args...) + return iterate(x.val, args...) +end + +Base.conj(x::InplaceableThunk) = conj(x.val) + +function Base.show(io::IO, x::InplaceableThunk) + println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") +end + +# The real reason we have this: +accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) + + """ NO_FIELDS @@ -255,7 +292,6 @@ function itself, when that function is not a closure. """ const NO_FIELDS = DNE() -#### """ differential(𝒟::Type, der) diff --git a/src/operations.jl b/src/operations.jl new file mode 100644 index 000000000..33de43efa --- /dev/null +++ b/src/operations.jl @@ -0,0 +1,45 @@ +# TODO: This all needs a fair bit of rethinking + +""" + accumulate(Δ, ∂) + +Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's +various `AbstractDifferential` types. + +See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +""" +accumulate(Δ, ∂) = Δ .+ ∂ + +""" + accumulate!(Δ, ∂) + +Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place, +storing the result in `Δ`. + +Note: this function may not actually store the result in `Δ` if `Δ` is immutable, +so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case. + +This function is overloadable by [`InplaceableThunk`s](@ref). +See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +""" +function accumulate!(Δ, ∂) + return materialize!(Δ, broadcastable(cast(Δ) + ∂)) +end + +accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂) + + + +""" + store!(Δ, ∂) + +Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. +potentially avoiding intermediate temporary allocations that might be +necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`) + +Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended +to be customizable for specific rules/input types. + +See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) +""" +store!(Δ, ∂) = materialize!(Δ, broadcastable(∂)) diff --git a/src/rule_types.jl b/src/rule_types.jl deleted file mode 100644 index 305b18d9f..000000000 --- a/src/rule_types.jl +++ /dev/null @@ -1,222 +0,0 @@ -""" -Subtypes of `AbstractRule` are types which represent the primitive derivative -propagation "rules" that can be composed to implement forward- and reverse-mode -automatic differentiation. - -More specifically, a `rule::AbstractRule` is a callable Julia object generally -obtained via calling [`frule`](@ref) or [`rrule`](@ref). Such rules accept -differential values as input, evaluate the chain rule using internally stored/ -computed partial derivatives to produce a single differential value, then -return that calculated differential value. - -For example: - -``` -julia> using ChainRulesCore: frule, rrule, AbstractRule - -julia> x, y = rand(2); - -julia> h, dh = frule(hypot, x, y); - -julia> h == hypot(x, y) -true - -julia> isa(dh, AbstractRule) -true - -julia> Δx, Δy = rand(2); - -julia> dh(Δx, Δy) == ((x / h) * Δx + (y / h) * Δy) -true - -julia> h, (dx, dy) = rrule(hypot, x, y); - -julia> h == hypot(x, y) -true - -julia> isa(dx, AbstractRule) && isa(dy, AbstractRule) -true - -julia> Δh = rand(); - -julia> dx(Δh) == (x / h) * Δh -true - -julia> dy(Δh) == (y / h) * Δh -true -``` - -See also: [`frule`](@ref), [`rrule`](@ref), [`Rule`](@ref), [`DNERule`](@ref), [`WirtingerRule`](@ref) -""" -abstract type AbstractRule end - -# this ensures that consumers don't have to special-case rule destructuring -Base.iterate(rule::AbstractRule) = (@warn "iterating rules is going away"; (rule, nothing)) -Base.iterate(::AbstractRule, ::Any) = nothing - -# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple -# in order to get the `i`th rule (assuming it's 1) -function Base.getindex(rule::AbstractRule, i::Integer) - @warn "iterating rules is going away" - return i == 1 ? rule : throw(BoundsError()) -end - - -##### -##### `Rule` -##### - - -""" - Rule(propation_function[, updating_function]) - -Return a `Rule` that wraps the given `propation_function`. It is assumed that -`propation_function` is a callable object whose arguments are differential -values, and whose output is a single differential value calculated by applying -internally stored/computed partial derivatives to the input differential -values. - -If an updating function is provided, it is assumed to have the signature `u(Δ, xs...)` -and to store the result of the propagation function applied to the arguments `xs` into -`Δ` in-place, returning `Δ`. - -For example: - -``` -frule(::typeof(*), x, y) = x * y, Rule((Δx, Δy) -> Δx * y + x * Δy) - -rrule(::typeof(*), x, y) = x * y, (Rule(ΔΩ -> ΔΩ * y'), Rule(ΔΩ -> x' * ΔΩ)) -``` - -See also: [`frule`](@ref), [`rrule`](@ref), [`accumulate`](@ref), [`accumulate!`](@ref), [`store!`](@ref) -""" -struct Rule{F,U<:Union{Function,Nothing}} <: AbstractRule - f::F - u::U -end - -# NOTE: Using `Core.Typeof` instead of `typeof` here so that if we define a rule for some -# constructor based on a `UnionAll`, we get `Rule{Type{Thing}}` instead of `Rule{UnionAll}` -Rule(f) = Rule{Core.Typeof(f),Nothing}(f, nothing) - -(rule::Rule)(args...) = rule.f(args...) - -Base.show(io::IO, rule::Rule{<:Any, Nothing}) = print(io, "Rule($(rule.f))") -Base.show(io::IO, rule::Rule) = print(io, "Rule($(rule.f), $(rule.u))") - - -##### -##### `DNERule` -##### - -""" - DNERule(args...) - -Construct a `DNERule` object, which is an `AbstractRule` that signifies that the -current function is not differentiable with respect to a particular parameter. -**DNE** is an abbreviation for Does Not Exist. -""" -struct DNERule <: AbstractRule end - -DNERule(args...) = DNE() - -##### -##### `WirtingerRule` -##### - -""" - WirtingerRule(primal::AbstractRule, conjugate::AbstractRule) - -Construct a `WirtingerRule` object, which is an `AbstractRule` that consists of -an `AbstractRule` for both the primal derivative ``∂/∂x`` and the conjugate -derivative ``∂/∂x̅``. If the domain `𝒟` of the function might be real, consider -calling `AbstractRule(𝒟, primal, conjugate)` instead, to make use of a more -efficient representation wherever possible. -""" -struct WirtingerRule{P<:AbstractRule,C<:AbstractRule} <: AbstractRule - primal::P - conjugate::C -end - -function (rule::WirtingerRule)(args...) - return Wirtinger(rule.primal(args...), rule.conjugate(args...)) -end - -""" - AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) - -Return a `Rule` evaluating to `primal(Δ) + conjugate(Δ)` if `𝒟 <: Real`, -otherwise return `WirtingerRule(P, C)`. -""" -function AbstractRule(𝒟::Type, primal::AbstractRule, conjugate::AbstractRule) - if 𝒟 <: Real || eltype(𝒟) <: Real - return Rule((args...) -> (primal(args...) + conjugate(args...))) - else - return WirtingerRule(primal, conjugate) - end -end - - -""" - accumulate(Δ, ∂) - -Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore' -various `AbstractDifferential` types. - -#TODO: update these docs - -This method intended to be customizable for specific rules/input types. For -example, here is pseudocode to overload `accumulate` w.r.t. a specific forward -differentiation rule for a given function `f`: - -``` -df(x) = # forward differentiation primitive implementation - -frule(::typeof(f), x) = (f(x), Rule(df)) - -accumulate(Δ, rule::Rule{typeof(df)}, x) = # customized `accumulate` implementation -``` - -See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -accumulate(Δ, ∂) = Δ .+ ∂ - -""" - accumulate!(Δ, rule::AbstractRule, args...) - -# TODO: Update these docs - -Similar to [`accumulate`](@ref), but compute `Δ + rule(args...)` in-place, -storing the result in `Δ`. - -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) -""" -function accumulate!(Δ, ∂) - return materialize!(Δ, broadcastable(cast(Δ) + ∂)) -end - -accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂) - -# TODO: replace this: -# accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, args...) - - -""" - store!(Δ, ∂) - -TODO: Rewrite these docs - -Compute `rule(args...)` and store the result in `Δ`, potentially avoiding -intermediate temporary allocations that might be necessary for alternative -approaches (e.g. `copyto!(Δ, extern(rule(args...)))`) - -Note that this function internally calls `Base.Broadcast.materialize!(Δ, ...)`. - -Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended -to be customizable for specific rules/input types. - -See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) -""" -store!(Δ, ∂) = materialize!(Δ, broadcastable(∂)) diff --git a/test/rule_types.jl b/test/rule_types.jl deleted file mode 100644 index 581d7fd34..000000000 --- a/test/rule_types.jl +++ /dev/null @@ -1,63 +0,0 @@ - -@testset "rule types" begin - # The following is deprecated and should be remove next release - @testset "iterating and indexing rules" begin - rule = Rule(identity) - i = 0 - for r in rule - @test r === rule - i += 1 - end - @test i == 1 # rules only iterate once, yielding themselves - @test rule[1] == rule - @test_throws BoundsError rule[2] - end - - - @testset "Rule" begin - @testset "show" begin - @test occursin(r"^Rule\(.*foo.*\)$", repr(Rule(function foo() 1 end))) - @test occursin(r"^Rule\(.*identity.*\)$", repr(Rule(identity))) - - @test occursin(r"^Rule\(.*identity.*\,.*\+.*\)$", repr(Rule(identity, +))) - end - end - - @testset "WirtingerRule" begin - myabs2(x) = abs2(x) - - function ChainRulesCore.frule(::typeof(myabs2), x) - return abs2(x), AbstractRule( - typeof(x), - Rule(Δx -> Δx * x'), - Rule(Δx -> Δx * x) - ) - end - - # real input - x = rand(Float64) - f, _df = @inferred frule(myabs2, x) - @test f === x^2 - - df = @inferred _df(One()) - @test df === x + x - - - Δ = rand(Complex{Int64}) - df = @inferred _df(Δ) - @test df === Δ * (x + x) - - - # complex input - z = rand(Complex{Float64}) - f, _df = @inferred frule(myabs2, z) - @test f === abs2(z) - - df = @inferred _df(One()) - @test df === Wirtinger(z', z) - - Δ = rand(Complex{Int64}) - df = @inferred _df(Δ) - @test df === Wirtinger(Δ * z', Δ * z) - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 4e10c970b..5f8552935 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,12 +4,10 @@ using ChainRulesCore using LinearAlgebra: Diagonal using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, - Zero, One, Casted, cast, - DNE, Thunk, Casted, DNERule, WirtingerRule + Zero, One, Casted, cast, DNE, Thunk using Base.Broadcast: broadcastable @testset "ChainRulesCore" begin include("differentials.jl") include("rules.jl") - include("rule_types.jl") end From 48a039127e06e06390fb1f092f8f7fb1a3d8d8af Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 11 Sep 2019 18:44:08 +0100 Subject: [PATCH 14/23] overload store! zero the storage inplace --- src/differentials.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/differentials.jl b/src/differentials.jl index 7fc7e710c..412bb4c8d 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -281,7 +281,7 @@ end # The real reason we have this: accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) - +store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero i, then add to it. """ NO_FIELDS From 752732e5ae9f9a1beed60d3b63b38a303d7cd41e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 16 Sep 2019 15:45:51 +0100 Subject: [PATCH 15/23] much AbstractThunk --- src/differential_arithmetic.jl | 37 +++++++++--------------- src/differentials.jl | 51 +++++++++++++++++----------------- 2 files changed, 38 insertions(+), 50 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index ad649797e..e65748d34 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered. Thus we can avoid any ambiguities. Notice: - The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any) + The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) Thus each of the @eval loops creating definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger) return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) end -for T in (:Casted, :Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any) +for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b @@ -47,7 +47,7 @@ end Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value)) Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value)) -for T in (:Zero, :DNE, :One, :Thunk, :InplaceableThunk, :Any) +for T in (:Zero, :DNE, :One, :AbstractThunk, :Any) @eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b)) @eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value)) @@ -58,7 +58,7 @@ end Base.:+(::Zero, b::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() -for T in (:DNE, :One, :Thunk, :InplaceableThunk, :Any) +for T in (:DNE, :One, :AbstractThunk, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a @@ -69,7 +69,7 @@ end Base.:+(::DNE, ::DNE) = DNE() Base.:*(::DNE, ::DNE) = DNE() -for T in (:One, :Thunk, :InplaceableThunk, :Any) +for T in (:One, :AbstractThunk, :Any) @eval Base.:+(::DNE, b::$T) = b @eval Base.:+(a::$T, ::DNE) = a @@ -80,7 +80,7 @@ end Base.:+(a::One, b::One) = extern(a) + extern(b) Base.:*(::One, ::One) = One() -for T in (:Thunk, :InplaceableThunk, :Any) +for T in (:AbstractThunk, :Any) @eval Base.:+(a::One, b::$T) = extern(a) + b @eval Base.:+(a::$T, b::One) = a + extern(b) @@ -89,23 +89,12 @@ for T in (:Thunk, :InplaceableThunk, :Any) end -Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b) -Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b) -for T in (:InplaceableThunk, :Any) - @eval Base.:+(a::Thunk, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::Thunk) = a + extern(b) +Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b) +Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b) +for T in (:Any,) + @eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b + @eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b) - @eval Base.:*(a::Thunk, b::$T) = extern(a) * b - @eval Base.:*(a::$T, b::Thunk) = a * extern(b) -end - -# InplaceableThunk acts just like Thunk -Base.:+(a::InplaceableThunk, b::InplaceableThunk) = extern(a) + extern(b) -Base.:*(a::InplaceableThunk, b::InplaceableThunk) = extern(a) * extern(b) -for T in (:Any, ) - @eval Base.:+(a::InplaceableThunk, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::InplaceableThunk) = a + extern(b) - - @eval Base.:*(a::InplaceableThunk, b::$T) = extern(a) * b - @eval Base.:*(a::$T, b::InplaceableThunk) = a * extern(b) + @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b + @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) end diff --git a/src/differentials.jl b/src/differentials.jl index 412bb4c8d..b70600387 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing) Base.iterate(::One, ::Any) = nothing +##### +##### `AbstractThunk +##### +abstract type AbstractThunk <: AbstractDifferential end + +Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x)) + +@inline function Base.iterate(x::AbstractThunk) + externed = extern(x) + element, state = iterate(externed) + return element, (externed, state) +end + +@inline function Base.iterate(::AbstractThunk, (externed, state)) + element, new_state = iterate(externed, state) + return element, (externed, new_state) +end + ##### ##### `Thunk` ##### @@ -218,7 +236,7 @@ itself a `Thunk`. If you got the expression from another `rrule` (or `frule`), you don't need to `@thunk` it since it will have been thunked if required, by the defining rule. """ -struct Thunk{F} <: AbstractDifferential +struct Thunk{F} <: AbstractThunk f::F end @@ -226,23 +244,12 @@ macro thunk(body) return :(Thunk(() -> $(esc(body)))) end -(x::Thunk)() = x.f() -@inline extern(x::Thunk) = extern(x()) - -Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x)) +# have to define this here after `@thunk` and `Thunk` is defined +Base.conj(x::AbstractThunk) = @thunk(conj(extern(x))) -@inline function Base.iterate(x::Thunk) - externed = extern(x) - element, state = iterate(externed) - return element, (externed, state) -end -@inline function Base.iterate(::Thunk, (externed, state)) - element, new_state = iterate(externed, state) - return element, (externed, new_state) -end - -Base.conj(x::Thunk) = @thunk(conj(extern(x))) +(x::Thunk)() = x.f() +@inline extern(x::Thunk) = extern(x()) Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") @@ -259,7 +266,7 @@ but it should do this more efficently than simply doing this directly. Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; and destroy its inplacability. """ -struct InplaceableThunk{T<:Thunk, F} <: AbstractDifferential +struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk val::T add!::F end @@ -267,21 +274,13 @@ end (x::InplaceableThunk)() = x.val() @inline extern(x::InplaceableThunk) = extern(x.val) -Base.Broadcast.broadcastable(x::InplaceableThunk) = broadcastable(x.val) - -@inline function Base.iterate(x::InplaceableThunk, args...) - return iterate(x.val, args...) -end - -Base.conj(x::InplaceableThunk) = conj(x.val) - function Base.show(io::IO, x::InplaceableThunk) println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") end # The real reason we have this: accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) -store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero i, then add to it. +store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. """ NO_FIELDS From 765ecfc6a2764828ae4197a76f4cc1e770cc6db3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 10:07:24 +0100 Subject: [PATCH 16/23] rename InplaceableThunk InplaceThunk --- src/ChainRulesCore.jl | 2 +- src/differentials.jl | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 275029cb8..b51027559 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -5,7 +5,7 @@ export frule, rrule export wirtinger_conjugate, wirtinger_primal, differential export @scalar_rule, @thunk export extern, cast, store! -export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk +export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceThunk export NO_FIELDS include("differentials.jl") diff --git a/src/differentials.jl b/src/differentials.jl index b70600387..95a849a30 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -254,33 +254,33 @@ Base.conj(x::AbstractThunk) = @thunk(conj(extern(x))) Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") """ - InplaceableThunk(val::Thunk, add!::Function) + InplaceThunk(val::Thunk, add!::Function) A wrapper for a `Thunk`, that allows it to define an inplace `add!` function, -which is used internally in `accumulate!(Δ, ::InplaceableThunk)`. +which is used internally in `accumulate!(Δ, ::InplaceThunk)`. `add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val` but it should do this more efficently than simply doing this directly. (Otherwise one can just use a normal `Thunk`). -Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; +Most operations on an `InplaceThunk` treat it just like a normal `Thunk`; and destroy its inplacability. """ -struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk +struct InplaceThunk{T<:Thunk, F} <: AbstractThunk val::T add!::F end -(x::InplaceableThunk)() = x.val() -@inline extern(x::InplaceableThunk) = extern(x.val) +(x::InplaceThunk)() = x.val() +@inline extern(x::InplaceThunk) = extern(x.val) -function Base.show(io::IO, x::InplaceableThunk) - println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") +function Base.show(io::IO, x::InplaceThunk) + println(io, "InplaceThunk($(repr(x.val)), $(repr(x.add!)))") end # The real reason we have this: -accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) -store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. +accumulate!(Δ, ∂::InplaceThunk) = ∂.add!(Δ) +store!(Δ, ∂::InplaceThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. """ NO_FIELDS From 62d3be3bb2fc4aa3e5a8bd1bb2b93ba5c3e0da84 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 10:23:01 +0100 Subject: [PATCH 17/23] remove reference to Rules from docstrings --- src/operations.jl | 4 +-- src/rule_definition_tools.jl | 2 +- src/rules.jl | 55 +++++++++++++----------------------- 3 files changed, 23 insertions(+), 38 deletions(-) diff --git a/src/operations.jl b/src/operations.jl index 33de43efa..c60134e9f 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -19,8 +19,8 @@ storing the result in `Δ`. Note: this function may not actually store the result in `Δ` if `Δ` is immutable, so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case. -This function is overloadable by [`InplaceableThunk`s](@ref). -See also: [`accumulate`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) +This function is overloadable by using a [`InplaceThunk`](@ref). +See also: [`accumulate`](@ref), [`store!`](@ref). """ function accumulate!(Δ, ∂) return materialize!(Δ, broadcastable(cast(Δ) + ∂)) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 5724412c5..ade48461b 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -151,7 +151,7 @@ is equivalent to: For examples, see ChainRulesCore' `rules` directory. -See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref) +See also: [`frule`](@ref), [`rrule`](@ref). """ macro scalar_rule(call, maybe_setup, partials...) ############################################################################ diff --git a/src/rules.jl b/src/rules.jl index 6a145fad7..9f891cc67 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -4,7 +4,7 @@ #= In some weird ideal sense, the fallback for e.g. `frule` should actually be "get -the derivative via forward-mode AD". This is necessary to enable mixed-mode +the derivative via forward-ode AD". This is necessary to enable mixed-mode rules, where e.g. `frule` is used within a `rrule` definition. For example, broadcasted functions may not themselves be forward-mode *primitives*, but are often forward-mode *differentiable*. @@ -33,7 +33,7 @@ my_frule(args...) = Cassette.overdub(MyChainRuleCtx(), frule, args...) function Cassette.execute(::MyChainRuleCtx, ::typeof(frule), f, x::Number) r = frule(f, x) if isa(r, Nothing) - fx, df = (f(x), Rule(Δx -> ForwardDiff.derivative(f, x) * Δx)) + fx, df = (f(x), (_, Δx) -> ForwardDiff.derivative(f, x) * Δx) else fx, df = r end @@ -48,16 +48,12 @@ end Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)` as `Ω`, return the tuple: - (Ω, (rule_for_ΔΩ₁::AbstractRule, rule_for_ΔΩ₂::AbstractRule, ...)) + (Ω, (ṡelf, ẋ₁, ẋ₂, ...) -> Ω̇₁, Ω̇₂, ...) -where each returned propagation rule `rule_for_ΔΩᵢ` can be invoked as +The second return value is the propagation rule, or the pushforward. +It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`) +and `ṡelf` the internal values of the function (for closures). - rule_for_ΔΩᵢ(Δx₁, Δx₂, ...) - -to yield `Ωᵢ`'s corresponding differential `ΔΩᵢ`. To illustrate, if all involved -values are real-valued scalars, this differential can be written as: - - ΔΩᵢ = ∂Ωᵢ_∂x₁ * Δx₁ + ∂Ωᵢ_∂x₂ * Δx₂ + ... If no method matching `frule(f, xs...)` has been defined, then return `nothing`. @@ -68,12 +64,12 @@ unary input, unary output scalar function: ``` julia> x = rand(); -julia> sinx, dsin = frule(sin, x); +julia> sinx, sin_pushforward = frule(sin, x); julia> sinx == sin(x) true -julia> dsin(1) == cos(x) +julia> sin_pushforward(NamedTuple(), 1) == cos(x) true ``` @@ -82,19 +78,16 @@ unary input, binary output scalar function: ``` julia> x = rand(); -julia> sincosx, (dsin, dcos) = frule(sincos, x); +julia> sincosx, sincos_pushforward = frule(sincos, x); julia> sincosx == sincos(x) true -julia> dsin(1) == cos(x) -true - -julia> dcos(1) == -sin(x) +julia> sincos_pushforward(NamedTuple(), 1) == (cos(x), -sin(x)) true ``` -See also: [`rrule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref) +See also: [`rrule`](@ref), [`@scalar_rule`](@ref) """ frule(::Any, ::Vararg{Any}; kwargs...) = nothing @@ -104,16 +97,11 @@ frule(::Any, ::Vararg{Any}; kwargs...) = nothing Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)` as `Ω`, return the tuple: - (Ω, (rule_for_Δx₁::AbstractRule, rule_for_Δx₂::AbstractRule, ...)) - -where each returned propagation rule `rule_for_Δxᵢ` can be invoked as + (Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...)) - rule_for_Δxᵢ(ΔΩ₁, ΔΩ₂, ...) - -to yield `xᵢ`'s corresponding differential `Δxᵢ`. To illustrate, if all involved -values are real-valued scalars, this differential can be written as: - - Δxᵢ = ∂Ω₁_∂xᵢ * ΔΩ₁ + ∂Ω₂_∂xᵢ * ΔΩ₂ + ... +Where the second return value is the the propagation rule or pullback. +It takes in differentials corresponding to the outputs (`x̄₁, x̄₂, ...`), +and `s̄elf`, the internal values of the function itself (for closures) If no method matching `rrule(f, xs...)` has been defined, then return `nothing`. @@ -124,12 +112,12 @@ unary input, unary output scalar function: ``` julia> x = rand(); -julia> sinx, dx = rrule(sin, x); +julia> sinx, sin_pullback = rrule(sin, x); julia> sinx == sin(x) true -julia> dx(1) == cos(x) +julia> sin_pullback(1) == (NO_FIELDS, cos(x)) true ``` @@ -138,18 +126,15 @@ binary input, unary output scalar function: ``` julia> x, y = rand(2); -julia> hypotxy, (dx, dy) = rrule(hypot, x, y); +julia> hypotxy, hypot_pullback = rrule(hypot, x, y); julia> hypotxy == hypot(x, y) true -julia> dx(1) == (x / hypot(x, y)) -true - -julia> dy(1) == (y / hypot(x, y)) +julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y))) true ``` -See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref) +See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ rrule(::Any, ::Vararg{Any}; kwargs...) = nothing From 3ae64494573b4e1a78376fba9ccf5ec7e3980978 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 10:27:33 +0100 Subject: [PATCH 18/23] update thunk docstring --- src/differentials.jl | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/differentials.jl b/src/differentials.jl index 95a849a30..9c0875920 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -201,7 +201,6 @@ A thunk is a deferred computation. It wraps a zero argument closure that when invoked returns a differential. `@thunk(v)` is a macro that expands into `Thunk(()->v)`. - Calling a thunk, calls the wrapped closure. `extern`ing thunks applies recursively, it also externs the differial that the closure returns. If you do not want that, then simply call the thunk @@ -222,19 +221,19 @@ julia> t()() ### When to `@thunk`? When writing `rrule`s (and to a lesser exent `frule`s), it is important to `@thunk` -appropriately. Propagation rule's that return multiple deriviatives are able to not -do all the work of computing all of them only to have just one used. -By `@thunk`ing the work required for each, they can only be computed when needed. +appropriately. +Propagation rule's that return multiple derivatives are not able to do all the computing themselves. + By `@thunk`ing the work required for each, they then compute only what is needed. #### So why not thunk everything? -`@thunk` creates a closure over the expression, which is basically a struct with a -field for each variable used in the expression (closed over), and call overloaded. -If this would be equal or more work than actually evaluating the expression then don't do -it. An example would be if the expression itself is just wrapping something in a struct. -Such as `Adjoint(x)` or `Diagonal(x)`. Or if the expression is a constant, or is -itself a `Thunk`. -If you got the expression from another `rrule` (or `frule`), you don't need to -`@thunk` it since it will have been thunked if required, by the defining rule. +`@thunk` creates a closure over the expression, which (effectively) creates a `struct` +with a field for each variable used in the expression, and call overloaded. + +Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being: +- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)` +- The expression being a constant +- The expression being itself a `thunk` +- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already) """ struct Thunk{F} <: AbstractThunk f::F From 03cb994b0ef43fccdafa0b276facb5e317321903 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 10:28:56 +0100 Subject: [PATCH 19/23] fix indent in docstring --- src/rule_definition_tools.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index ade48461b..ad0ac0312 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -25,13 +25,13 @@ propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.valu """ propagation_expr(𝒟, Δs, ∂s) - Returns the expression for the propagation of - the input gradient `Δs` though the partials `∂s`. +Returns the expression for the propagation of +the input gradient `Δs` though the partials `∂s`. - 𝒟 is an expression that when evaluated returns the type-of the input domain. - For example if the derivative is being taken at the point `1` it returns `Int`. - if it is taken at `1+1im` it returns `Complex{Int}`. - At present it is ignored for non-Wirtinger derivatives. +𝒟 is an expression that when evaluated returns the type-of the input domain. +For example if the derivative is being taken at the point `1` it returns `Int`. +if it is taken at `1+1im` it returns `Complex{Int}`. +At present it is ignored for non-Wirtinger derivatives. """ function propagation_expr(𝒟, Δs, ∂s) wirtinger_indices = findall(∂s) do ex From cb017433a2a479442e13ac1cbbf26a97ea345ff1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 10:07:24 +0100 Subject: [PATCH 20/23] Revert "rename InplaceableThunk InplaceThunk" This reverts commit 85b5bf9899e3d209c86b8583343d041640ee6344. --- src/ChainRulesCore.jl | 2 +- src/differentials.jl | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index b51027559..275029cb8 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -5,7 +5,7 @@ export frule, rrule export wirtinger_conjugate, wirtinger_primal, differential export @scalar_rule, @thunk export extern, cast, store! -export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceThunk +export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk export NO_FIELDS include("differentials.jl") diff --git a/src/differentials.jl b/src/differentials.jl index 9c0875920..f3c063b9c 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -253,33 +253,33 @@ Base.conj(x::AbstractThunk) = @thunk(conj(extern(x))) Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") """ - InplaceThunk(val::Thunk, add!::Function) + InplaceableThunk(val::Thunk, add!::Function) A wrapper for a `Thunk`, that allows it to define an inplace `add!` function, -which is used internally in `accumulate!(Δ, ::InplaceThunk)`. +which is used internally in `accumulate!(Δ, ::InplaceableThunk)`. `add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val` but it should do this more efficently than simply doing this directly. (Otherwise one can just use a normal `Thunk`). -Most operations on an `InplaceThunk` treat it just like a normal `Thunk`; +Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; and destroy its inplacability. """ -struct InplaceThunk{T<:Thunk, F} <: AbstractThunk +struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk val::T add!::F end -(x::InplaceThunk)() = x.val() -@inline extern(x::InplaceThunk) = extern(x.val) +(x::InplaceableThunk)() = x.val() +@inline extern(x::InplaceableThunk) = extern(x.val) -function Base.show(io::IO, x::InplaceThunk) - println(io, "InplaceThunk($(repr(x.val)), $(repr(x.add!)))") +function Base.show(io::IO, x::InplaceableThunk) + println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") end # The real reason we have this: -accumulate!(Δ, ∂::InplaceThunk) = ∂.add!(Δ) -store!(Δ, ∂::InplaceThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. +accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) +store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. """ NO_FIELDS From 3656389baf091bd1812cae9ca0b632824e2f9200 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 11:02:32 +0100 Subject: [PATCH 21/23] rename differential to refine_differential --- src/ChainRulesCore.jl | 2 +- src/differentials.jl | 6 +++--- src/rule_definition_tools.jl | 2 +- test/differentials.jl | 12 ++++++------ 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 275029cb8..118e7f841 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable export frule, rrule -export wirtinger_conjugate, wirtinger_primal, differential +export wirtinger_conjugate, wirtinger_primal, refine_differential export @scalar_rule, @thunk export extern, cast, store! export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk diff --git a/src/differentials.jl b/src/differentials.jl index f3c063b9c..5ad2f8818 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -291,14 +291,14 @@ function itself, when that function is not a closure. const NO_FIELDS = DNE() """ - differential(𝒟::Type, der) + refine_differential(𝒟::Type, der) Converts, if required, a differential object `der` (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), to another differential that is more suited for the domain given by the type 𝒟. Often this will behave as the identity function on `der`. """ -function differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) +function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) return wirtinger_primal(w) + wirtinger_conjugate(w) end -differential(::Any, der) = der # most of the time leave it alone. +refine_differential(::Any, der) = der # most of the time leave it alone. diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index ad0ac0312..202b9490c 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -81,7 +81,7 @@ function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) return quote # This will be a block, so will have value equal to last statement $(∂_wirtinger_defs...) w = Wirtinger($primal_sum, $conjugate_sum) - differential($𝒟, w) + refine_differential($𝒟, w) end end diff --git a/test/differentials.jl b/test/differentials.jl index 73c76da83..570b09d88 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -81,17 +81,17 @@ end - @testset "Differential" begin - @test differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2) - @test differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2) + @testset "Refine Differential" begin + @test refine_differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2) + @test refine_differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2) - @test differential(typeof(1.2), Wirtinger(2,2)) == 4 - @test differential(typeof([1.2]), Wirtinger(2,2)) == 4 + @test refine_differential(typeof(1.2), Wirtinger(2,2)) == 4 + @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4 # For most differentials, in most domains, this does nothing for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2])) - @test differential(𝒟, der) === der + @test refine_differential(𝒟, der) === der end end end From 1869be1ecd3a8266193a8b60c6239a0567fe2755 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 16:55:31 +0100 Subject: [PATCH 22/23] Update src/rules.jl Co-Authored-By: Nick Robinson --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 9f891cc67..371013ced 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -4,7 +4,7 @@ #= In some weird ideal sense, the fallback for e.g. `frule` should actually be "get -the derivative via forward-ode AD". This is necessary to enable mixed-mode +the derivative via forward-mode AD". This is necessary to enable mixed-mode rules, where e.g. `frule` is used within a `rrule` definition. For example, broadcasted functions may not themselves be forward-mode *primitives*, but are often forward-mode *differentiable*. From e51ff806d05a389eebb8e1a10e796bfc52ca9481 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 17 Sep 2019 18:45:09 +0100 Subject: [PATCH 23/23] split up scalar_rule into a bunch of functions --- src/rule_definition_tools.jl | 310 +++++++++++++++++++---------------- 1 file changed, 170 insertions(+), 140 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 202b9490c..a06820e64 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,90 +1,5 @@ # These are some macros (and supporting functions) to make it easier to define rules. -""" - propagator_name(f, propname) - -Determines a reasonable name for the propagator function. -The name doesn't really matter too much as it is a local function to be returned -by `frule` or `rrule`, but a good name make debugging easier. -`f` should be some form of AST representation of the actual function, -`propname` should be either `:pullback` or `:pushforward` - -This is able to deal with fairly complex expressions for `f`: - - julia> propagator_name(:bar, :pushforward) - :bar_pushforward - - julia> propagator_name(esc(:(Base.Random.foo)), :pullback) - :foo_pullback -""" -propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname) -propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) -propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) - - -""" - propagation_expr(𝒟, Δs, ∂s) - -Returns the expression for the propagation of -the input gradient `Δs` though the partials `∂s`. - -𝒟 is an expression that when evaluated returns the type-of the input domain. -For example if the derivative is being taken at the point `1` it returns `Int`. -if it is taken at `1+1im` it returns `Complex{Int}`. -At present it is ignored for non-Wirtinger derivatives. -""" -function propagation_expr(𝒟, Δs, ∂s) - wirtinger_indices = findall(∂s) do ex - Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger - end - ∂s = map(esc, ∂s) - if isempty(wirtinger_indices) - return standard_propagation_expr(Δs, ∂s) - else - return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) - end -end - -function standard_propagation_expr(Δs, ∂s) - # This is basically Δs ⋅ ∂s - - # Notice: the thunking of `∂s[i] (potentially) saves us some computation - # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon - # as the pullback is evaluated - ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] - return :(+($(∂_mul_Δs...))) -end - -function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) - ∂_mul_Δs_primal = Any[] - ∂_mul_Δs_conjugate = Any[] - ∂_wirtinger_defs = Any[] - for i in 1:length(∂s) - if i in wirtinger_indices - Δi = Δs[i] - ∂i = Symbol(string(:∂, i)) - push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) - ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) - ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) - ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) - ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) - push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) - push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) - else - ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) - push!(∂_mul_Δs_primal, ∂_mul_Δ) - push!(∂_mul_Δs_conjugate, ∂_mul_Δ) - end - end - primal_sum = :(+($(∂_mul_Δs_primal...))) - conjugate_sum = :(+($(∂_mul_Δs_conjugate...))) - return quote # This will be a block, so will have value equal to last statement - $(∂_wirtinger_defs...) - w = Wirtinger($primal_sum, $conjugate_sum) - refine_differential($𝒟, w) - end -end - """ @scalar_rule(f(x₁, x₂, ...), @setup(statement₁, statement₂, ...), @@ -151,9 +66,49 @@ is equivalent to: For examples, see ChainRulesCore' `rules` directory. -See also: [`frule`](@ref), [`rrule`](@ref). +See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref) """ macro scalar_rule(call, maybe_setup, partials...) + call, setup_stmts, inputs, partials = _normalize_scalarrules_macro_input( + call, maybe_setup, partials + ) + f = call.args[1] + + # An expression that when evaluated will return the type of the input domain. + # Multiple repetitions of this expression should optimize out. But if it does not then + # may need to move its definition into the body of the `rrule`/`frule` + 𝒟 = :(typeof(first(promote($(call.args[2:end]...))))) + + frule_expr = scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials) + rrule_expr = scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials) + + + ############################################################################ + # Final return: building the expression to insert in the place of this macro + code = quote + if !($f isa Type) && fieldcount(typeof($f)) > 0 + throw(ArgumentError( + "@scalar_rule cannot be used on closures/functors (such as $($f))" + )) + end + + $(frule_expr) + $(rrule_expr) + end +end + + +""" + _normalize_scalarrules_macro_input(call, maybe_setup, partials) + +returns (in order) the correctly escaped: + - `call` with out any type constraints + - `setup_stmts`: the content of `@setup` or `nothing` if that is not provided, + - `inputs`: with all args having the constraints removed from call, or + defaulting to `Number` + - `partials`: which are all `Expr{:tuple,...}` +""" +function _normalize_scalarrules_macro_input(call, maybe_setup, partials) ############################################################################ # Setup: normalizing input form etc @@ -164,12 +119,12 @@ macro scalar_rule(call, maybe_setup, partials...) partials = (maybe_setup, partials...) end @assert Meta.isexpr(call, :call) - f = esc(call.args[1]) # Annotate all arguments in the signature as scalars inputs = map(call.args[2:end]) do arg esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number)) end + # Remove annotations and escape names for the call for (i, arg) in enumerate(call.args) if Meta.isexpr(arg, :(::)) @@ -189,74 +144,65 @@ macro scalar_rule(call, maybe_setup, partials...) end end - ############################################################################ - # Main body: defining the results of the frule/rrule - - # An expression that when evaluated will return the type of the input domain. - # Multiple repetitions of this expression should optimize out. But if it does not then - # may need to move its definition into the body of the `rrule`/`frule` - 𝒟 = :(typeof(first(promote($(call.args[2:end]...))))) + return call, setup_stmts, inputs, partials +end +function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials) n_outputs = length(partials) n_inputs = length(inputs) - pushforward = let - # Δs is the input to the propagator rule - # because this is push-forward there is one per input to the function - Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] - pushforward_returns = map(1:n_outputs) do output_i - ∂s = partials[output_i].args - propagation_expr(𝒟, Δs, ∂s) - end - if n_outputs > 1 - # For forward-mode we only return a tuple if output actually a tuple. - pushforward_returns = Expr(:tuple, pushforward_returns...) - else - pushforward_returns = pushforward_returns[1] - end - quote - # _ is the input derivative w.r.t. function internals. since we do not - # allow closures/functors with @scalar_rule, it is always ignored - function $(propagator_name(f, :pushforward))(_, $(Δs...)) - $pushforward_returns - end - end + # Δs is the input to the propagator rule + # because this is push-forward there is one per input to the function + Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] + pushforward_returns = map(1:n_outputs) do output_i + ∂s = partials[output_i].args + propagation_expr(𝒟, Δs, ∂s) end - - pullback = let - # Δs is the input to the propagator rule - # because this is a pull-back there is one per output of function - Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs] - - # 1 partial derivative per input - pullback_returns = map(1:n_inputs) do input_i - ∂s = [partial.args[input_i] for partial in partials] - propagation_expr(𝒟, Δs, ∂s) - end - - quote - function $(propagator_name(f, :pullback))($(Δs...)) - return (NO_FIELDS, $(pullback_returns...)) - end - end + if n_outputs > 1 + # For forward-mode we only return a tuple if output actually a tuple. + pushforward_returns = Expr(:tuple, pushforward_returns...) + else + pushforward_returns = pushforward_returns[1] end - ############################################################################ - # Final return: building the expression to insert in the place of this macro - - code = quote - if fieldcount(typeof($f)) > 0 - throw(ArgumentError( - "@scalar_rule cannot be used on closures/functors (such as $f)" - )) + pushforward = quote + # _ is the input derivative w.r.t. function internals. since we do not + # allow closures/functors with @scalar_rule, it is always ignored + function $(propagator_name(f, :pushforward))(_, $(Δs...)) + $pushforward_returns end + end + return quote function ChainRulesCore.frule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $pushforward end + end +end + +function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials) + n_outputs = length(partials) + n_inputs = length(inputs) + + # Δs is the input to the propagator rule + # because this is a pull-back there is one per output of function + Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs] + + # 1 partial derivative per input + pullback_returns = map(1:n_inputs) do input_i + ∂s = [partial.args[input_i] for partial in partials] + propagation_expr(𝒟, Δs, ∂s) + end + + pullback = quote + function $(propagator_name(f, :pullback))($(Δs...)) + return (NO_FIELDS, $(pullback_returns...)) + end + end + return quote function ChainRulesCore.rrule(::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) @@ -264,3 +210,87 @@ macro scalar_rule(call, maybe_setup, partials...) end end end + +""" + propagation_expr(𝒟, Δs, ∂s) + + Returns the expression for the propagation of + the input gradient `Δs` though the partials `∂s`. + + 𝒟 is an expression that when evaluated returns the type-of the input domain. + For example if the derivative is being taken at the point `1` it returns `Int`. + if it is taken at `1+1im` it returns `Complex{Int}`. + At present it is ignored for non-Wirtinger derivatives. +""" +function propagation_expr(𝒟, Δs, ∂s) + wirtinger_indices = findall(∂s) do ex + Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger + end + ∂s = map(esc, ∂s) + if isempty(wirtinger_indices) + return standard_propagation_expr(Δs, ∂s) + else + return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) + end +end + +function standard_propagation_expr(Δs, ∂s) + # This is basically Δs ⋅ ∂s + + # Notice: the thunking of `∂s[i] (potentially) saves us some computation + # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon + # as the pullback is evaluated + ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] + return :(+($(∂_mul_Δs...))) +end + +function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) + ∂_mul_Δs_primal = Any[] + ∂_mul_Δs_conjugate = Any[] + ∂_wirtinger_defs = Any[] + for i in 1:length(∂s) + if i in wirtinger_indices + Δi = Δs[i] + ∂i = Symbol(string(:∂, i)) + push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) + ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) + ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) + ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) + ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) + push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) + push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) + else + ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) + push!(∂_mul_Δs_primal, ∂_mul_Δ) + push!(∂_mul_Δs_conjugate, ∂_mul_Δ) + end + end + primal_sum = :(+($(∂_mul_Δs_primal...))) + conjugate_sum = :(+($(∂_mul_Δs_conjugate...))) + return quote # This will be a block, so will have value equal to last statement + $(∂_wirtinger_defs...) + w = Wirtinger($primal_sum, $conjugate_sum) + refine_differential($𝒟, w) + end +end + +""" + propagator_name(f, propname) + +Determines a reasonable name for the propagator function. +The name doesn't really matter too much as it is a local function to be returned +by `frule` or `rrule`, but a good name make debugging easier. +`f` should be some form of AST representation of the actual function, +`propname` should be either `:pullback` or `:pushforward` + +This is able to deal with fairly complex expressions for `f`: + + julia> propagator_name(:bar, :pushforward) + :bar_pushforward + + julia> propagator_name(esc(:(Base.Random.foo)), :pullback) + :foo_pullback +""" +propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname) +propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) +propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)