reset Zygote context cache after each call#70
reset Zygote context cache after each call#70marius311 wants to merge 1 commit intoJuliaDiff:masterfrom
Conversation
Codecov ReportBase: 83.33% // Head: 83.85% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #70 +/- ##
==========================================
+ Coverage 83.33% 83.85% +0.51%
==========================================
Files 6 6
Lines 474 483 +9
==========================================
+ Hits 395 405 +10
+ Misses 79 78 -1
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
ChrisRackauckas
left a comment
There was a problem hiding this comment.
This definitely seems like the correct thing to do.
|
I wonder if this is a Zygote issue that should be fixed in Zygote. And more generally, if it is OK at all to mutate RuleConfigs in rrule/frule calls. |
|
I agree with David. This doesn't seem AbstractDifferentiation specific at all. |
|
I think this has to be here since we are calling into |
|
Independent of whether the fix belongs here or in Zygote, marking as draft for now while I think through some consequences of clearing the cache after each jacobian column (which is what this PR currently does) brought up by @ToucheSir on Slack. |
Short of advanced compiler magic which doesn't exist for pure Julia ADs yet, not sure there's a way around storing mutable state in RuleConfigs if one wants to support For the issue at hand, one solution could be to give |
|
Thanks for thinking about this again @ToucheSir. Would be great to get #69 fixed in some way. I'm pretty sure this is above my paygrade though unfortunately, so can't really say I expect to be able to contribute here further. For posterity at least here's the example you gave on Slack that convinced me there may be issues with this PR: julia> using Zygote
julia> function f(x, a)
r = Ref(x)
r[] = r[] + r[]
r[] = r[] * a
r[]
end
f (generic function with 1 method)
julia> Zygote.gradient(f, 1, 3)
(6.0, 2.0)
julia> Zygote.withjacobian(f, [1 2 3], 3)
(val = [6, 12, 18], grad = ([6 0 0; 0 6 0; 0 0 6], [2, 4, 6]))
julia> function withjacobian(f, args...) # same behaviour as Zygote.withjacobian, just with cache clearing after each pullback call
cx = Zygote.Context()
y, back = Zygote.pullback(Zygote._jvec∘f, cx, args...)
out = map(args) do x
T = promote_type(eltype(x), eltype(y))
dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) :
x isa Number ? similar(y, T, length(y)) :
nothing
end
delta = Zygote._eyelike(y)
for k in LinearIndices(y)
grads = back(delta[:,k])
for (dx, grad) in zip(out, grads)
dx isa AbstractArray || continue
Zygote._gradcopy!(view(dx,k,:), grad)
end
@show cx
empty!(cx.cache)
end
(val=y, grad=out)
end
withjacobian (generic function with 1 method)
julia> withjacobian(f, [1 2 3], 3)
cx = Zygote.Context{false}(IdDict{Any, Any}(Base.RefValue{Matrix{Int64}}([6 12 18]) => Base.RefValue{Any}((x = nothing,))))
cx = Zygote.Context{false}(IdDict{Any, Any}(Base.RefValue{Matrix{Int64}}([6 12 18]) => Base.RefValue{Any}((x = [0 1 0],))))
cx = Zygote.Context{false}(IdDict{Any, Any}(Base.RefValue{Matrix{Int64}}([6 12 18]) => Base.RefValue{Any}((x = [0 0 1],))))
(val = [6, 12, 18], grad = ([6 0 0; 0 0 0; 0 0 0], [2, 0, 0]))The However, this PR actually does give the correct result for that jacobian, but you then pointed out its because AD.jl gets a new pullback function for every column, which seems inefficient, so perhaps the "fix" involves considering more moving pieces than I originally thought. |
|
It should be a relatively small change. I'd be happy to file a PR once #68 lands. |
|
the plane landed. |
Fixes #69
I did the reset after to allow GC immediately. Perhaps there's better ways, happy to take comments.