diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 911a32ddd..8d77955d1 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -158,6 +158,7 @@ end Compute the derivative of scalar function `f` at primal input point `xs...`, given that it had primal output `Ω`. + Return a tuple of tuples with the partial derivatives of `f` with respect to the `xs...`. The derivative of the `i`-th component of `f` with respect to the `j`-th input can be accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. @@ -173,16 +174,43 @@ accessed as `Df[i][j]`, where `Df = derivatives_given_output(Ω, f, xs...)`. """ function derivatives_given_output end +""" + derivatives_given_output(f, xs...) + +Compute the derivative of scalar function `f` at primal input point `xs...`, +when this is possible *without* knowing primal output `Ω`. + +!!! warning "Experimental" +""" +function derivatives_given_input end + function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials) - return @strip_linenos quote - function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) + given_output = @strip_linenos quote + @inline function ChainRulesCore.derivatives_given_output($(esc(:Ω)), ::Core.Typeof($f), $(inputs...)) $(__source__) $(setup_stmts...) return $(Expr(:tuple, partials...)) end end + given_input = @strip_linenos quote + @inline function ChainRulesCore.derivatives_given_input(::Core.Typeof($f), $(inputs...)) + $(__source__) + $(setup_stmts...) + return $(Expr(:tuple, partials...)) + end + end + return if _free_of_omega([inputs, partials]) + :($given_output; $given_input) + else + given_output + end end +_free_of_omega(v::Union{Vector,Tuple}) = all(_free_of_omega, v) +_free_of_omega(ex::Expr) = _free_of_omega(ex.args) +_free_of_omega(s::Symbol) = s != :Ω +_free_of_omega(other) = true # (@show other typeof(other); true) + function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) n_outputs = length(partials) n_inputs = length(inputs)