Allow AD systems to register hooks so they can create new overloads in response to new rules #182
Allow AD systems to register hooks so they can create new overloads in response to new rules #182
Conversation
|
it's difficult to see that i.e. since this PR does nothing it's difficult to see why it shouldn't wait til it does something / be part of a PR that actually adds functionality. I'm sure it is going to be useful... but it's pretty weird to add in isolation. |
|
Getting these annotation into ChainRules.jl, even if they do nothing open us up to solving either of the two proposals. It could wait for the first of those to be added. Psedudocode for how this would be used to do: #127 (comment) |
|
@nickrobinson251 so you are declining this PR, and would like one of the proposals that need it to be implemented without this part way step? |
Persoanlly I would prefer us not to add this to the code until we are actually making use of it (i.e. til it does something). I'm happy for this to be it's own PR, that goes in alongside (/immediately before) other PRs that make it do something useful. Basically I just don't want us to have |
44d2e1f to
f3e1637
Compare
Completely rewritten to actually use the macros now
| end | ||
| # @show fdef | ||
| @eval $fdef | ||
| end |
There was a problem hiding this comment.
because this code never uses the AST I am now questioning if we even need to capture that.
If we don't then that is great because we can get rid of the FRULES and RRULES,
and just use methods(frule) and methods(rrule) instead.
There was a problem hiding this comment.
We maybe could even get rid of the @frule and @rrule macros then, and the magic that detects the latest defined method, and go back to original proposal of using hooks attached to on_package_load
There was a problem hiding this comment.
using hooks attached to
on_package_load
What is this proposal? Is it written down somewhere? If the hooks are not triggered by @frule/@rrule macros, how are they triggered? Some Revise-style magic?
There was a problem hiding this comment.
It is mentioned here
#127 (comment)
and yes it is the same hook in Base that power's Revise:
Base.package_callbacks
it would be triggerd whenever a package is loaded.
Probably we would provide a manual ChainRulesCore.refresh also for people to use in the REPL
There was a problem hiding this comment.
If @eval will be called by the AD package, then we run into namespace issues. The correct behavior for dispatch-based AD is to overload the AD package's methods in the correct namespace outside the AD package, i.e. in the module where the rule is defined. Is this possible using this hook mechanism?
There was a problem hiding this comment.
But what if Base is replaced with another module and ... uses code from that other module, and this module is not available in ForwardDiffZero?
There was a problem hiding this comment.
@eval here evaluates the code in ForwardDiffZero so it will complain if we use code not loaded in ForwardDiffZero iiuc.
There was a problem hiding this comment.
But what if Base is replaced with another module and ... uses code from that other module, and this module is not available in ForwardDiffZero?
the ... is being written by the author of ForwardDiffZero and i don't think has any reason to run anything other than functions from ForwardDiffZero, or from ChainRuleCore (which is loaded by ForwardDiffZero).
Probably this is a reason not to include the AST as trying to use that for something in ... will run into this.
I think its just not needed anyway.
The interesting bit is I guess is in the sig, for a Source to Source AD, like Zygote.
that is not generating overloads of op(overloaded_equiv.(args)....), but rather pullback(op, args...)
Maybe the types of the args would need the same opname = :($(parentmodule(op)).$(nameof(op))) type escaping,
though even that wouldn't work because it would be a Symbol.
But can it be a type directly and thus not need to be given a path in local scope?
Yes, that seems to work
julia> K = Base.Fix2
Base.Fix2
julia> eval(:(foo(x::$K) = x))
foo (generic function with 1 method)
julia> foo(Base.Fix2(+,1))
(::Base.Fix2{typeof(+),Int64}) (generic function with 1 method)There was a problem hiding this comment.
We can also do this trick to avoid qualifying names for the operation.
Though it has to be in call overload form for some reason.
(Still i guess that is fine since more generic if also overloading functors)
julia> struct Foo end
julia> K = typeof(Base.:+) # this would come in the sig tuple, just using Base as example
typeof(+)
julia> @eval (::$K)(::Foo, ::Foo) = 2
julia> +(Foo(), Foo())
2
julia> +(2,1)
3
julia> methods(+, (Foo, Foo))
# 1 method for generic function "+":
[1] +(::Foo, ::Foo) in Main at REPL[17]:1
julia> methods(+)
# 167 methods for generic function "+":
[1] +(x::Bool, z::Complex{Bool}) in Base at complex.jl:282
[2] +(x::Bool, y::Bool) in Base at bool.jl:96
[3] +(x::Bool) in Base at bool.jl:93There was a problem hiding this comment.
I am now using this trick in both demos.
Can this be considered resolved?
willtebbutt
left a comment
There was a problem hiding this comment.
This is looking great so far. No serious concerns. It would be great if @mohamed82008 could comment as he's going to be the first proper consumer of this stuff in ReverseDiff.
| end | ||
| # @show fdef | ||
| @eval $fdef | ||
| end |
There was a problem hiding this comment.
using hooks attached to
on_package_load
What is this proposal? Is it written down somewhere? If the hooks are not triggered by @frule/@rrule macros, how are they triggered? Some Revise-style magic?
test/demos/forwarddiffzero.jl
Outdated
| function derv(f, args...) | ||
| duals = Dual.(args, one.(args)) | ||
| return diff(f(duals...)) | ||
| end |
There was a problem hiding this comment.
60 LOC for an AD system's pretty good going 👏
There was a problem hiding this comment.
90 LOC for reverse mode. 😁
f6a7179 to
c632d35
Compare
|
No more macros. |
nickrobinson251
left a comment
There was a problem hiding this comment.
Assuming this gives us everything we need (for #127), this seems an amazingly efficient use of code. Good work!
src/rules.jl
Outdated
| """ | ||
| refresh_rules() = (refresh_rules(frule); refresh_rules(rrule)) | ||
| function refresh_rules(rule_kind) | ||
| already_done_world_age = last_refresh(rule_kind)[] |
There was a problem hiding this comment.
should we check here if there are any hooks regisistered, and if not bail out early?
So that if someone has something that uses ChainRulesCore, but doesn't use any overloadinging AD package, then we don't spend the time going though the method table?
There was a problem hiding this comment.
Just a couple of truly minor points. Otherwise, this looks really good.
I particularly like that you make propagate a closure, so that a given Tracked knows how to propagate the output of its pullback to its parents. Very clean.
When I wrote Nabla.jl I had a separate bit of control that handled accumulating things to the right places.
The include callback runs before the file is included, so it is not useful to us. I tested and the Package load callback runs after, so it is useful.
|
I had to remove the hook triggering on As i mentioned above we may like to use
|
nickrobinson251
left a comment
There was a problem hiding this comment.
Looks great. A few tiny comments from a quick read through. I'll finish reviewing tomorrow :)
Co-authored-by: Nick Robinson <npr251@gmail.com>
Co-authored-by: Nick Robinson <npr251@gmail.com>
Co-authored-by: Nick Robinson <npr251@gmail.com>
willtebbutt
left a comment
There was a problem hiding this comment.
This is nearly there. Just a bunch of typos / docs improvements + a suggestion for an extra test.
nickrobinson251
left a comment
There was a problem hiding this comment.
Finally got round to reading the docs. Really good work on those too!
Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> Co-authored-by: Nick Robinson <npr251@gmail.com>
Docs preview: https://www.juliadiff.org/ChainRulesCore.jl/previews/PR182/
Closes #127
Basically we require the rule definitions to be wrapped in macros then we scope up the AST, then evaluate the AST outselves then fineout what method was just defined.Then we remember what method goes to what AST in a rule list. And we trigger hooks that the AD system can register which will be used to generate a new AST that it will `eval` to define its own rule equivelent overloads. Or it might choose not to, it can use `method.sig` to determine if this is a rule it wants to deal with. When the AD is initially loaded it should trigger its hooks on the whole rule_list.TODO:
This PR adds macros to wrap all defintions of rules. Right now they do nothing at all. We may like to later tag a breaking change for when they are actually doing something and are thus required.But I am introducing them now so we can do things with them later.
in particular #127 just capturing the AST at the time it is created seems like a much simplier way to accomplish this goal.
Related: #44, e.g. this will allow us to setup splatting to work in
frule((partials...), f, args...).Once this is merged I will make the follow-ups to ChainRules and ChainRulesTestUtils.