Currently, the approach in ReverseAD is to generate symbolic expressions for the first- and second-order derivatives of classical univariate functions using Calculus.jl:
https://github.com/jump-dev/MathOptInterface.jl/blob/100eab2e669e73689e1dc214391d97c24402e35c/src/Nonlinear/univariate_expressions_generator.jl
Then, given a representation of the operator as an Int, we do a hard-coded binary search so that we evaluate O(log(n)) Int comparisons instead of O(n):
https://github.com/jump-dev/MathOptInterface.jl/blob/100eab2e669e73689e1dc214391d97c24402e35c/src/Nonlinear/operators.jl#L570-L582
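To make the idea concrete, here is a minimal sketch of a hard-coded binary search over operator ids. It is not the code in MathOptInterface; the function name `eval_univariate` and the four-operator id table are assumptions made up for illustration. With four operators, any call performs at most two id comparisons after the range check, instead of up to three in a linear chain.

```julia
# Hypothetical operator ids, for illustration only:
# 1 => cos, 2 => exp, 3 => sin, 4 => tan.
function eval_univariate(id::Int, x::Float64)
    1 <= id <= 4 || throw(ArgumentError("unknown operator id $id"))
    if id <= 2          # first split: ids 1-2 vs ids 3-4
        if id == 1
            return cos(x)
        else
            return exp(x)
        end
    else
        if id == 3
            return sin(x)
        else
            return tan(x)
        end
    end
end
```

Every branch calls a concrete function and returns a `Float64`, so the compiler can infer the return type even though `id` is only known at run time.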
I'm wondering whether we could instead get closer to ChainRules, like other Julia AD frameworks.
The naive way to do this would be:
op = :tanh
f = eval(op)
value_and_derivative(f, 1)
The issue is that, because the value of op is only known at run time, the type of f cannot be inferred at compile time, so the call to value_and_derivative is type-unstable.
But we can use the same if-else trick and do:
if op == :tanh
    value_and_derivative(tanh, x)
elseif op == :tan
    value_and_derivative(tan, x)
elseif ...
else
    value_and_derivative(eval(op), x)
end
Again, we can do a binary search instead of a linear chain of if-else.
So, for a fixed set of symbols, the if-else avoids the type instability, and the eval branch is a fallback for all the other symbols.
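A minimal sketch of this dispatch, under stated assumptions: `value_and_derivative` is a hypothetical helper defined by hand here (a ChainRules-based version would obtain the derivative from a rule instead), and `eval_op` is a made-up name. The `exp` method only exists so that the fallback branch has something to call.

```julia
# Hypothetical helper returning (f(x), f'(x)); one method per function.
function value_and_derivative(::typeof(tanh), x::Real)
    y = tanh(x)
    return y, 1 - y^2
end

function value_and_derivative(::typeof(tan), x::Real)
    y = tan(x)
    return y, 1 + y^2
end

function value_and_derivative(::typeof(exp), x::Real)
    y = exp(x)
    return y, y
end

function eval_op(op::Symbol, x::Float64)
    if op == :tanh
        # The function is a literal here, so this call is statically dispatched.
        return value_and_derivative(tanh, x)
    elseif op == :tan
        return value_and_derivative(tan, x)
    else
        # Fallback: the function is looked up at run time, so this call
        # is dynamically dispatched.
        f = Core.eval(@__MODULE__, op)
        return value_and_derivative(f, x)
    end
end
```

For example, `eval_op(:tanh, 0.0)` takes the first branch, while `eval_op(:exp, 0.0)` goes through the `eval` fallback and still finds the right method by dispatch.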
That would also mean that, for registered functions, one would implement a method and we would rely on multiple dispatch instead of adding an operator to the list of user-defined operators. User-defined operators already trigger a type instability when they are called anyway.
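As a rough illustration of what registration through multiple dispatch could look like: the names `value_and_derivative`, `RegisteredOperator`, `evaluate` and `my_square` below are all hypothetical and not part of MathOptInterface or ChainRules.

```julia
# Hypothetical generic entry point; errors for functions without a method.
value_and_derivative(f, x) = error("no derivative method defined for $f")

# A user-defined function, plus the method that "registers" its derivative.
my_square(x) = x^2
value_and_derivative(::typeof(my_square), x::Real) = (my_square(x), 2 * x)

# Hypothetical storage on the solver side. The field has the abstract type
# `Function`, so calling through it is dynamically dispatched, much like the
# eval fallback.
struct RegisteredOperator
    f::Function
end

evaluate(op::RegisteredOperator, x) = value_and_derivative(op.f, x)
```

With this design, `evaluate(RegisteredOperator(my_square), 3.0)` finds the user's method by dispatch, and no separate list of user-defined operators has to be maintained.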