From f37d6a3c2223d36c648a481179fa2c74bcad2b72 Mon Sep 17 00:00:00 2001 From: piever Date: Wed, 8 Sep 2021 13:03:38 +0200 Subject: [PATCH 1/7] add derivatives_given_output --- src/rule_definition_tools.jl | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 10912ce61..453b1d7dc 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -88,8 +88,15 @@ macro scalar_rule(call, maybe_setup, partials...) ) f = call.args[1] - frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) - rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) + # Generate variables to store derivatives named + derivatives = map(keys(partials)) do i + syms = map(j -> gensym("df$(i)/dx$(j)"), keys(inputs)) + return Expr(:tuple, syms...) + end + + derivative_expr = scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) + frule_expr = scalar_frule_expr(__source__, f, call, [], inputs, derivatives) + rrule_expr = scalar_rrule_expr(__source__, f, call, [], inputs, derivatives) # Final return: building the expression to insert in the place of this macro code = quote @@ -99,6 +106,7 @@ macro scalar_rule(call, maybe_setup, partials...) )) end + $(derivative_expr) $(frule_expr) $(rrule_expr) end @@ -145,6 +153,26 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) return call, setup_stmts, inputs, partials end +""" + derivatives_given_output(Ω, f, xs...) + +Compute the derivative of scalar function `f` with inputs `xs...` and output `Ω`. +This is used within the implementation of [`@scalar_rule`](@ref) and is not +considered part of the stable API. +If the output is scalar, return a tuple with partial derivatives with respect to the `xs`. +If the output is a tuple, return a tuple of tuples. +""" +function derivatives_given_output end + +function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) + return @strip_linenos quote + function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) + $(__source__) + $(setup_stmts...) + return $(esc(Expr(:tuple, partials...))) + end + end +end function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) n_outputs = length(partials) @@ -173,6 +201,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) + $(esc(Expr(:tuple, partials...))) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -210,6 +239,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) + $(esc(Expr(:tuple, partials...))) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end From 6d29b7ceab953b2c810f4aed84aaecb3247aeb75 Mon Sep 17 00:00:00 2001 From: piever Date: Wed, 8 Sep 2021 13:20:08 +0200 Subject: [PATCH 2/7] esc fixes --- src/rule_definition_tools.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 453b1d7dc..e29f41cbd 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -88,9 +88,9 @@ macro scalar_rule(call, maybe_setup, partials...) ) f = call.args[1] - # Generate variables to store derivatives named + # Generate variables to store derivatives named dfi/dxj derivatives = map(keys(partials)) do i - syms = map(j -> gensym("df$(i)/dx$(j)"), keys(inputs)) + syms = map(j -> esc(gensym(Symbol("df", i, "/dx", j))), keys(inputs)) return Expr(:tuple, syms...) end @@ -143,10 +143,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) # For consistency in code that follows we make all partials tuple expressions partials = map(partials) do partial if Meta.isexpr(partial, :tuple) - partial + Expr(:tuple, map(esc, partial.args)...) else length(inputs) == 1 || error("Invalid use of `@scalar_rule`") - Expr(:tuple, partial) + Expr(:tuple, esc(partial)) end end @@ -169,7 +169,7 @@ function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) $(__source__) $(setup_stmts...) - return $(esc(Expr(:tuple, partials...))) + return $(Expr(:tuple, partials...)) end end end @@ -201,7 +201,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(esc(Expr(:tuple, partials...))) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pushforward_returns end end @@ -239,7 +239,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) - $(esc(Expr(:tuple, partials...))) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc(:Ω)), $f, $(inputs...)) return $(esc(:Ω)), $pullback end end @@ -270,9 +270,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj - :(conj($(esc(∂s_i)))) + :(conj($∂s_i)) else - esc(∂s_i) + ∂s_i end end From 69e1e4f1d2df00cf41f9d5f74caaca99c9ace459 Mon Sep 17 00:00:00 2001 From: piever Date: Wed, 8 Sep 2021 14:32:45 +0200 Subject: [PATCH 3/7] test derivatives_given_output and correct docs --- src/rule_definition_tools.jl | 5 +++-- test/rule_definition_tools.jl | 7 ++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index e29f41cbd..7239cecd9 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -159,8 +159,9 @@ end Compute the derivative of scalar function `f` with inputs `xs...` and output `Ω`. This is used within the implementation of [`@scalar_rule`](@ref) and is not considered part of the stable API. -If the output is scalar, return a tuple with partial derivatives with respect to the `xs`. -If the output is a tuple, return a tuple of tuples. +It returns a tuple of tuples with the partial derivatives of `f` with respect to the `xs`. +The derivative of the `i`-th component of `f` with respect to the `j`-th input can be accessed +as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. """ function derivatives_given_output end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index f30233bfb..0d6d98535 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -235,6 +235,9 @@ end @test ẏ == Tangent{typeof(y)}(50f0, 100f0) # make sure type is exactly as expected: @test ẏ isa Tangent{Tuple{Irrational{:π}, Float64}, Tuple{Float32, Float32}} + + xs, Ω = (3,), (3, 6) + @test ChainRulesCore.derivatives_given_output(Ω, simo, xs...) == ((1f0,), (2f0,)) end @testset "@scalar_rule projection" begin @@ -298,7 +301,7 @@ module IsolatedModuleForTestingScoping module IsolatedSubmodule # check that rules defined in isolated module without imports can be called # without errors - using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent + using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id using Test @@ -328,6 +331,8 @@ module IsolatedModuleForTestingScoping y, f_pullback = rrule(my_id, x) @test y == x @test f_pullback(Δy) == (NoTangent(), Δy) + + @test derivatives_given_output(y, my_id, x) == ((1.0,),) end end end From 4fd98d4b82837a3fd5b616e4d9a1539b611f0b81 Mon Sep 17 00:00:00 2001 From: Pietro Vertechi Date: Wed, 8 Sep 2021 18:49:05 +0200 Subject: [PATCH 4/7] Update src/rule_definition_tools.jl Co-authored-by: Lyndon White --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 7239cecd9..0d10d27fe 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -156,7 +156,7 @@ end """ derivatives_given_output(Ω, f, xs...) -Compute the derivative of scalar function `f` with inputs `xs...` and output `Ω`. +Compute the derivative of scalar function `f` at primal input point `xs...`, given that it had primal output `Ω`. This is used within the implementation of [`@scalar_rule`](@ref) and is not considered part of the stable API. It returns a tuple of tuples with the partial derivatives of `f` with respect to the `xs`. From 7d99bb68c5ed59befff05b51dc955ce12bc25927 Mon Sep 17 00:00:00 2001 From: piever Date: Wed, 8 Sep 2021 18:53:16 +0200 Subject: [PATCH 5/7] add experimental warning --- src/rule_definition_tools.jl | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 0d10d27fe..57cb35430 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -156,12 +156,20 @@ end """ derivatives_given_output(Ω, f, xs...) -Compute the derivative of scalar function `f` at primal input point `xs...`, given that it had primal output `Ω`. -This is used within the implementation of [`@scalar_rule`](@ref) and is not -considered part of the stable API. -It returns a tuple of tuples with the partial derivatives of `f` with respect to the `xs`. -The derivative of the `i`-th component of `f` with respect to the `j`-th input can be accessed -as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. +Compute the derivative of scalar function `f` at primal input point `xs...`, +given that it had primal output `Ω`. +Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`. +The derivative of the `i`-th component of `f` with respect to the `j`-th input can be +accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. + +!!! warning "Experimental" + This function is experimental and not part of the stable API. + At the moment, it can be considered an implementation detail of the macro + [`@scalar_rule`](@ref), in which it is used. + In the future, the exact semantics of this function will stabilize, and it + will be added to the stable API. + When that happens, this warning will be removed. + """ function derivatives_given_output end From 2339e05bacdc86e2ac81f70d39fd0d844e33ef01 Mon Sep 17 00:00:00 2001 From: Pietro Vertechi Date: Wed, 8 Sep 2021 19:42:54 +0200 Subject: [PATCH 6/7] Update src/rule_definition_tools.jl Co-authored-by: Lyndon White --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 57cb35430..636bb26a5 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -90,7 +90,7 @@ macro scalar_rule(call, maybe_setup, partials...) # Generate variables to store derivatives named dfi/dxj derivatives = map(keys(partials)) do i - syms = map(j -> esc(gensym(Symbol("df", i, "/dx", j))), keys(inputs)) + syms = map(j -> Symbol("∂f", i, "/∂x", j)), keys(inputs) return Expr(:tuple, syms...) end From d8e8f63430b5368f8ce0e1775ddccbf7acacea6a Mon Sep 17 00:00:00 2001 From: piever Date: Wed, 8 Sep 2021 19:48:49 +0200 Subject: [PATCH 7/7] fix typo --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 636bb26a5..911a32ddd 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -90,7 +90,7 @@ macro scalar_rule(call, maybe_setup, partials...) # Generate variables to store derivatives named dfi/dxj derivatives = map(keys(partials)) do i - syms = map(j -> Symbol("∂f", i, "/∂x", j)), keys(inputs) + syms = map(j -> Symbol("∂f", i, "/∂x", j), keys(inputs)) return Expr(:tuple, syms...) end