diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 10912ce61..911a32ddd 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -88,8 +88,15 @@ macro scalar_rule(call, maybe_setup, partials...) ) f = call.args[1] - frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) - rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) + # Generate one symbol ∂fi/∂xj per (output i, input j) pair to hold that partial derivative + derivatives = map(keys(partials)) do i + syms = map(j -> Symbol("∂f", i, "/∂x", j), keys(inputs)) + return Expr(:tuple, syms...) + end + + derivative_expr = scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) + frule_expr = scalar_frule_expr(__source__, f, call, [], inputs, derivatives) + rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) # Final return: building the expression to insert in the place of this macro code = quote @@ -99,6 +106,7 @@ macro scalar_rule(call, maybe_setup, partials...) )) end + $(derivative_expr) $(frule_expr) $(rrule_expr) end @@ -135,16 +143,45 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) # For consistency in code that follows we make all partials tuple expressions partials = map(partials) do partial if Meta.isexpr(partial, :tuple) - partial + Expr(:tuple, map(esc, partial.args)...) else length(inputs) == 1 || error("Invalid use of `@scalar_rule`") - Expr(:tuple, partial) + Expr(:tuple, esc(partial)) end end return call, setup_stmts, inputs, partials end +""" + derivatives_given_output(Ω, f, xs...) + +Compute the derivative of scalar function `f` at primal input point `xs...`, +given that it had primal output `Ω`. +Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`. +The derivative of the `i`-th component of `f` with respect to the `j`-th input can be +accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. + +!!! 
warning "Experimental" + This function is experimental and not part of the stable API. + At the moment, it can be considered an implementation detail of the macro + [`@scalar_rule`](@ref), in which it is used. + In the future, the exact semantics of this function will stabilize, and it + will be added to the stable API. + When that happens, this warning will be removed. + +""" +function derivatives_given_output end + +function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) + return @strip_linenos quote + function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) + $(__source__) + $(setup_stmts...) + return $(Expr(:tuple, partials...)) + end + end +end function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) n_outputs = length(partials) @@ -173,6 +210,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -210,6 +248,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) 
+ $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end @@ -240,9 +279,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj - :(conj($(esc(∂s_i)))) + :(conj($∂s_i)) else - esc(∂s_i) + ∂s_i end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index f30233bfb..0d6d98535 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -235,6 +235,9 @@ end @test ẏ == Tangent{typeof(y)}(50f0, 100f0) # make sure type is exactly as expected: @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} + + xs, Ω = (3,), (3, 6) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) end @testset "@scalar_rule projection" begin @@ -298,7 +301,7 @@ module IsolatedModuleForTestingScoping module IsolatedSubmodule # check that rules defined in isolated module without imports can be called # without errors - using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent + using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id using Test @@ -328,6 +331,8 @@ module IsolatedModuleForTestingScoping y, f_pullback = rrule(my_id, x) @test y == x @test f_pullback(Δy) == (NoTangent(), Δy) + + @test derivatives_given_output(y, my_id, x) == ((1.0,),) end end end