Improve DiffRules integration and tests #209
Conversation
Codecov Report
Base: 85.16% // Head: 81.24% // Decreases project coverage by -3.93%.
Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master     #209      +/-   ##
==========================================
- Coverage   85.16%   81.24%    -3.93%
==========================================
  Files          18       18
  Lines        1861     1578      -283
==========================================
- Hits         1585     1282      -303
- Misses        276      296       +20
```
Bump 🙂 It would be good to fix ReverseDiff such that we can move forward with JuliaDiff/DiffRules.jl#79.
Sorry for the delay, been swamped recently. I will take a look tonight. |
mohdibntarek left a comment
Do we have tests for ForwardOptimize where both x and y are tracked? Seems there might be a method ambiguity error in this case?
Ah, the methods are defined a bit further down in the same file, never mind.
```julia
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::TrackedArray{X}, y::$A) where {F,X}
    result = DiffResults.GradientResult(SVector(zero(X)))
    df = (vx, vy) -> let vy=vy
        ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
```
Since `s` is an `SVector` with a single element `vx`, which we want to use here. That's just the one-argument version of the current implementation on the master branch:
ReverseDiff.jl/src/derivatives/elementwise.jl, lines 116 to 117 in d522508:
```julia
@eval function Base.$(g!)(f::ForwardOptimize{F}, out::TrackedArray, x::$A, y::TrackedArray{Y}) where {F,Y}
    result = DiffResults.GradientResult(SVector(zero(Y)))
    df = (vx, vy) -> let vx=vx
        ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
```
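To see the single-element `SVector` trick in isolation, here is a standalone sketch using a hypothetical function `f(x, y) = x * sin(y)` (not code from this PR): the scalar `vx` is wrapped in a one-element `SVector`, so ForwardDiff differentiates a one-argument closure and the partial with respect to `x` lands in the single gradient slot.

```julia
using ForwardDiff, StaticArrays, DiffResults

f(x, y) = x * sin(y)

vx, vy = 2.0, 1.0

# Wrap vx in a one-element SVector and close over vy, so ForwardDiff
# differentiates the one-argument function s -> f(s[1], vy).
result = DiffResults.GradientResult(SVector(zero(vx)))
result = ForwardDiff.gradient!(result, s -> f(s[1], vy), SVector(vx))

DiffResults.value(result)        # f(2.0, 1.0) == 2 * sin(1.0)
DiffResults.gradient(result)[1]  # ∂f/∂x == sin(1.0)
```

Note that `GradientResult` on an `SVector` is immutable, which is why the result of `gradient!` is rebound instead of mutated in place.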
```julia
@eval function Base.$(g)(f::ForwardOptimize{F}, x::TrackedArray{X,D}, y::$A) where {F,X,D}
    result = DiffResults.GradientResult(SVector(zero(X)))
    df = (vx, vy) -> let vy=vy
        ForwardDiff.gradient!(result, s -> f.f(s[1], vy), SVector(vx))
```
```julia
istracked(b) && diffresult_increment_deriv!(b, output_deriv, results, 2, b_bound)
if istracked(a)
    p += 1
    diffresult_increment_deriv!(a, output_deriv, results, p, a_bound)
```
Why change the value of `p` here?
To extract the correct partial. If `a` is tracked, its corresponding partial has index `p = 1`; if only `b` is tracked, the first partial (`p = 1`) corresponds to `b`; and if both `a` and `b` are tracked, `p = 1` corresponds to `a` and `p = 2` to `b`. So incrementing `p` in the branches allows us to avoid checking and handling all three scenarios separately.
Note that on the master branch `p = 1` for `a` and `p = 2` for `b` are hardcoded. That only works because on the master branch the partials with respect to both arguments are always computed and stored, even if only one argument is tracked.
```julia
end
if istracked(b)
    p += 1
    diffresult_increment_deriv!(b, output_deriv, results, p, b_bound)
```
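The index bookkeeping can also be sketched on its own (a simplified model of the partial indexing; `partial_indices` is a hypothetical helper, not part of ReverseDiff):

```julia
# Partials are stored only for tracked arguments, in argument order.
# p holds the last assigned partial index; it advances only when an
# argument is tracked, so the a-only, b-only, and both-tracked cases
# all read the correct slot without three separate branches.
function partial_indices(a_tracked::Bool, b_tracked::Bool)
    p = 0
    idx = Dict{Symbol,Int}()
    if a_tracked
        p += 1
        idx[:a] = p
    end
    if b_tracked
        p += 1
        idx[:b] = p
    end
    return idx
end

partial_indices(true, true)   # Dict(:a => 1, :b => 2)
partial_indices(false, true)  # Dict(:b => 1)
```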
```julia
@eval function Base.$(g)(f::ForwardOptimize{F}, x::$A, y::TrackedArray{Y,D}) where {F,Y,D}
    result = DiffResults.GradientResult(SVector(zero(Y)))
    df = (vx, vy) -> let vx=vx
        ForwardDiff.gradient!(result, s -> f.f(vx, s[1]), SVector(vy))
```
This PR fixes some problems with the DiffRules integration and its tests. It is needed for JuliaDiff/DiffRules.jl#79 (relevant DiffRules tests pass with that PR).
Mainly, the PR

- adds `NaN` comparisons to the tests (necessary since in DiffRules undefined and non-existing derivatives are implemented as `NaN`, and hence otherwise comparisons with ForwardDiff will fail if both return `NaN`)
- compares derivatives with ForwardDiff only for tracked arguments (if a derivative is `NaN`, the ForwardDiff results of both approaches are different for the derivative of the other argument: one will be `NaN` and one might not)
- computes derivatives only for tracked arguments in the `map`/broadcasting of DiffRules functions, as currently internally derivatives are computed with ForwardDiff always for both arguments, even if only one is tracked (this was uncovered by the changes to the tests mentioned above, and it ensures that derivatives of functions whose derivatives are defined only for one argument return non-`NaN` results, as in ForwardDiff)

~~The `vcat` test error is unrelated and also present on the master branch and other PRs.~~ Edit: Fixed on the master branch.

I also assume we could do better than ForwardDiff here and also avoid that all results become `NaN` if derivatives are computed with respect to both arguments and only one is defined/exists. But replacing ForwardDiff with a direct implementation of the DiffRules derivatives seemed to require much larger changes, and I tried to apply only a somewhat minimal set of changes required for JuliaDiff/DiffRules.jl#79.
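For instance, a `NaN`-aware comparison for such tests could look like the following sketch (`deriv_equal` is a hypothetical helper, not the actual test code):

```julia
# Two derivative values agree if they are approximately equal, or if
# both are NaN (DiffRules encodes undefined derivatives as NaN, and
# NaN == NaN as well as NaN ≈ NaN are false in Julia).
deriv_equal(x, y) = (isnan(x) && isnan(y)) || x ≈ y

deriv_equal(NaN, NaN)          # true
deriv_equal(1.0, 1.0 + 1e-12)  # true
deriv_equal(NaN, 1.0)          # false
```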