Skip to content

reset Zygote context cache after each call#70

Closed
marius311 wants to merge 1 commit intoJuliaDiff:masterfrom
marius311:empty_cache
Closed

reset Zygote context cache after each call#70
marius311 wants to merge 1 commit intoJuliaDiff:masterfrom
marius311:empty_cache

Conversation

@marius311
Copy link

Fixes #69

I did the reset after to allow GC immediately. Perhaps there's better ways, happy to take comments.

@marius311 marius311 changed the title reset Zygote context cache before each call reset Zygote context cache after each call Jan 5, 2023
@codecov-commenter
Copy link

codecov-commenter commented Jan 5, 2023

Codecov Report

Base: 83.33% // Head: 83.85% // Increases project coverage by +0.51% 🎉

Coverage data is based on head (48d8ee4) compared to base (eb5d913).
Patch coverage: 88.88% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #70      +/-   ##
==========================================
+ Coverage   83.33%   83.85%   +0.51%     
==========================================
  Files           6        6              
  Lines         474      483       +9     
==========================================
+ Hits          395      405      +10     
+ Misses         79       78       -1     
Impacted Files Coverage Δ
src/ruleconfig.jl 90.00% <83.33%> (-10.00%) ⬇️
src/AbstractDifferentiation.jl 79.24% <100.00%> (+0.82%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

Copy link
Member

@ChrisRackauckas ChrisRackauckas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This definitely seems like the correct thing to do.

@devmotion
Copy link
Member

I wonder if this is a Zygote issue that should be fixed in Zygote. And more generally, if it is OK at all to mutate RuleConfigs in rrule/frule calls.

@mohdibntarek
Copy link
Member

I agree with David. This doesn't seem AbstractDifferentiation specific at all.

@marius311
Copy link
Author

I think this has to be here since we are calling into ChainRulesCore.rrule_via_ad, but you can't clear the cache anywhere inside that function or what it calls since that has no guarantees not to be nested in fact often is; AD.pullback_function is the best indication this is a top-level call, ie, its the place to clear the cache.

@marius311 marius311 marked this pull request as draft January 5, 2023 21:14
@marius311
Copy link
Author

Independent of whether the fix belongs here or in Zygote, marking as draft for now while I think through some consequences of clearing the cache after each jacobian column (which is what this PR currently does) brought up by @ToucheSir on Slack.

@ToucheSir
Copy link

And more generally, if it is OK at all to mutate RuleConfigs in rrule/frule calls.

Short of advanced compiler magic which doesn't exist for pure Julia ADs yet, not sure there's a way around storing mutable state in RuleConfigs if one wants to support setfield!.

For the issue at hand, one solution could be to give ZygoteBackend its own @primitive definitions which create a fresh RuleConfig on each call. That would avoid any jacobian pullback reuse issues and keep any Zygote-specific modifications contained within the ZygoteBackend source.

@marius311
Copy link
Author

marius311 commented Mar 1, 2023

Thanks for thinking about this again @ToucheSir. Would be great to get #69 fixed in some way. I'm pretty sure this is above my paygrade though unfortunately, so can't really say I expect to be able to contribute here further.

For posterity at least here's the example you gave on Slack that convinced me there may be issues with this PR:

julia> using Zygote

julia> function f(x, a)
           r = Ref(x)
           r[] = r[] + r[]
           r[] = r[] * a
           r[]
       end
f (generic function with 1 method)

julia> Zygote.gradient(f, 1, 3)
(6.0, 2.0)

julia> Zygote.withjacobian(f, [1 2 3], 3)
(val = [6, 12, 18], grad = ([6 0 0; 0 6 0; 0 0 6], [2, 4, 6]))

julia> function withjacobian(f, args...) # same behaviour as Zygote.withjacobian, just with cache clearing after each pullback call
         cx = Zygote.Context()
         y, back = Zygote.pullback(Zygote._jvecf, cx, args...)
         out = map(args) do x
           T = promote_type(eltype(x), eltype(y))
           dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) :
             x isa Number ? similar(y, T, length(y)) :
             nothing
         end
         delta = Zygote._eyelike(y)
         for k in LinearIndices(y)
           grads = back(delta[:,k])
           for (dx, grad) in zip(out, grads)
             dx isa AbstractArray || continue
             Zygote._gradcopy!(view(dx,k,:), grad)
           end
           @show cx
           empty!(cx.cache)
         end
         (val=y, grad=out)
       end
withjacobian (generic function with 1 method)

julia> withjacobian(f, [1 2 3], 3)
cx = Zygote.Context{false}(IdDict{Any, Any}(Base.RefValue{Matrix{Int64}}([6 12 18]) => Base.RefValue{Any}((x = nothing,))))
cx = Zygote.Context{false}(IdDict{Any, Any}(Base.RefValue{Matrix{Int64}}([6 12 18]) => Base.RefValue{Any}((x = [0 1 0],))))
cx = Zygote.Context{false}(IdDict{Any, Any}(Base.RefValue{Matrix{Int64}}([6 12 18]) => Base.RefValue{Any}((x = [0 0 1],))))
(val = [6, 12, 18], grad = ([6 0 0; 0 0 0; 0 0 0], [2, 0, 0]))

The function withjacobian definition tries to emulate this PR by clearing the cache after each column, but gives an incorrect grad at the end.

However, this PR actually does give the correct result for that jacobian, but you then pointed out its because AD.jl gets a new pullback function for every column, which seems inefficient, so perhaps the "fix" involves considering more moving pieces than I originally thought.

@ToucheSir
Copy link

It should be a relatively small change. I'd be happy to file a PR once #68 lands.

@ChrisRackauckas
Copy link
Member

the plane landed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Zygote context cache incorrectly(?) persists between AD calls

6 participants