Allow AD systems to register hooks so they can create new overloads in response to new rules #182
Merged
43 commits
684661a
move using MulAddMacro to right place
oxinabox 90a87ea
Add frule and rrule decorator macros
oxinabox 0ad8235
Update src/rules.jl
oxinabox 8de4ca9
Initial sketch of capturing the AST and feeding it to new rule hooks
oxinabox 4476c4e
sort out API for overload generation
oxinabox 8a1fb9c
add ForwardDiffZero as an API integration test
oxinabox 83d7e8b
Revert "Add frule and rrule decorator macros"
oxinabox ffeb861
use refresh_rules either manually or autoamtically on pkg load / file…
oxinabox 78fd24f
directly interpolate function type in
oxinabox 0a741db
replace missed opname with op [fixme]
oxinabox a0129d0
don't handle multi-input
oxinabox d4efa8e
Add ReverseDiffZero demo
oxinabox 3375791
remove excess new lines
oxinabox 463166d
Update test/demos/reversediffzero.jl
oxinabox 536dac3
Update test/demos/forwarddiffzero.jl
oxinabox 4de2b6b
Update test/demos/reversediffzero.jl
oxinabox f4ed7c9
Apply suggestions from code review
oxinabox 1872cb1
more comments
oxinabox 37dda58
Apply suggestions from code review
oxinabox 6a04dac
use paritial for all deriviative parts in demos
oxinabox dd083af
remove debug stuff
oxinabox 562fe72
tweak comments etc
oxinabox 0422e9c
start writing docs for using overload generation
oxinabox 65e8c3d
working on docs
oxinabox 6e58754
finish first pass at docs
oxinabox faa2087
more docs
oxinabox a909e2d
handle Unionall Signatures
oxinabox fb8cdf6
Stop refreshing rules on include_callback
oxinabox b8d1581
tweaks
oxinabox 27a8592
remove type_constraint_equal
oxinabox 36b6410
Update test/demos/reversediffzero.jl
oxinabox e87b845
Style and comment fixes
oxinabox 1d91366
Don't export clear_new_rule_hooks!
oxinabox 4ec4981
Update docs/make.jl
oxinabox baf4431
move comemnt
oxinabox ada4822
fix dotpoints in docs
oxinabox cfee703
fix clear rule hooks in tests
oxinabox 3320fba
bump version
oxinabox fdca95c
Apply suggestions from code review
oxinabox d97535f
Update docs/src/autodiff/operator_overloading.md
oxinabox 7d32509
More docs on generation
oxinabox ecd2bb6
test clear hooks
oxinabox eccb894
wrap up code review
oxinabox
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| # Operator Overloading | ||
|
|
||
| The principal interface for using the operator overload generation method is [`on_new_rule`](@ref). | ||
| This function allows one to register a hook to be run every time a new rule is defined. | ||
| The hook receives a signature type-type as input, and generally will use `eval` to define | ||
| an overload of an AD system's overloaded type. | ||
| For example, using the signature type `Tuple{typeof(+), Real, Real}` to make | ||
| `+(::DualNumber, ::DualNumber)` call the `frule` for `+`. | ||
| A signature type tuple always has the form: | ||
| `Tuple{typeof(operation), typeof(pos_arg1), typeof(pos_arg2), ...}`, where `pos_arg1` is the | ||
| first positional argument. | ||
| You can dispatch on the signature type so that rules whose argument types your AD does not support never reach `eval`; | ||
| or, more simply, you can check a condition inside the hook. | ||
| For example, if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs, you might write: | ||
| ```julia | ||
| const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}} | ||
| function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}}) where F | ||
| @eval begin | ||
| # ... | ||
| end | ||
| end | ||
| define_overload(::Any) = nothing # don't do anything for any other signature | ||
|
|
||
| on_new_rule(define_overload, frule) | ||
| ``` | ||
|
|
||
| or you might write: | ||
| ```julia | ||
| const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64}) | ||
| function define_overload(sig) | ||
| sig = Base.unwrap_unionall(sig) # note: this ignores most UnionAll constraints | ||
| opT, argTs = Iterators.peel(sig.parameters) | ||
| all(any(acceptT <: argT for acceptT in ACCEPT_TYPES) for argT in argTs) \|\| return | ||
| @eval begin | ||
| # ... | ||
| end | ||
| end | ||
|
|
||
| on_new_rule(define_overload, frule) | ||
| ``` | ||
|
|
||
| The generation of overloaded code is the responsibility of the AD implementor. | ||
| Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this. | ||
| It's generally fairly simple, though it can become complex if you need to handle complicated type-constraints. | ||
| Examples are shown below. | ||
|
|
||
| The hook is automatically triggered whenever a package is loaded. | ||
| It can also be triggered manually using [`refresh_rules`](@ref). | ||
| This is useful, for example, if new rules are defined in the REPL, or if a package defining rules is modified | ||
| (Revise.jl will not trigger the hooks automatically). | ||
| When the rules are refreshed (automatically or manually), the hooks are triggered only on new or modified rules, not on rules for which they have already run. | ||
|
|
||
| [`clear_new_rule_hooks!`](@ref) clears all registered hooks. | ||
| It is useful for undoing [`on_new_rule`](@ref) hook registration if you are iteratively developing your overload generation function. | ||
|
|
||
| ## Examples | ||
|
|
||
| ### ForwardDiffZero | ||
| The overload generation hook in this example is: `define_dual_overload`. | ||
|
|
||
| ````@eval | ||
| using Markdown | ||
| Markdown.parse(""" | ||
| ```julia | ||
| $(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) | ||
| ``` | ||
| """) | ||
| ```` | ||
|
|
||
| ### ReverseDiffZero | ||
| The overload generation hook in this example is: `define_tracked_overload`. | ||
|
|
||
| ````@eval | ||
| using Markdown | ||
| Markdown.parse(""" | ||
| ```julia | ||
| $(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) | ||
| ``` | ||
| """) | ||
| ```` | ||
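A minimal sketch of a condition-based hook, as described in the documentation above, might look like the following. It does not use ChainRulesCore: `Dual`, `toy_frule`, and `define_dual_overload_sketch` are hypothetical names introduced only for illustration, the hook is called by hand rather than registered with `on_new_rule`, and only unary signatures that accept a `Float64` are handled.

```julia
# Hypothetical dual number type for a toy forward-mode AD (not from any package).
struct Dual <: Real
    primal::Float64
    partial::Float64
end

# Hypothetical stand-in for an `frule`: returns (primal output, output tangent).
toy_frule((df, dx), ::typeof(sin), x::Float64) = (sin(x), cos(x) * dx)

# Hook: receives a signature type such as `Tuple{typeof(sin), Float64}`.
function define_dual_overload_sketch(sig)
    sig = Base.unwrap_unionall(sig)   # assumption: type-variable constraints are ignored
    opT = sig.parameters[1]           # type of the primal function
    argTs = sig.parameters[2:end]     # types of its positional arguments
    # Condition-based filter: accept only one argument whose type admits Float64.
    ok = length(argTs) == 1 && argTs[1] isa Type && Float64 <: argTs[1]
    ok || return nothing
    op = opT.instance                 # the function itself, from its singleton type
    # Define the overload for the AD's own type; it forwards to the rule.
    @eval function (::$opT)(d::Dual)
        y, dy = toy_frule((nothing, d.partial), $op, d.primal)
        return Dual(y, dy)
    end
    return nothing
end

define_dual_overload_sketch(Tuple{typeof(sin), Float64})       # defines sin(::Dual)
define_dual_overload_sketch(Tuple{typeof(*), String, String})  # filtered out, no eval

d = sin(Dual(0.0, 1.0))  # now dispatches to the generated overload
```

Interpolating `opT` and `op` into the `@eval` expression means the generated method refers to the exact function the signature named, which is the same idea as the "directly interpolate function type in" commit in this pull request.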
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| # Using ChainRules in your AD system | ||
|
|
||
| This section is for authors of AD systems. | ||
| It assumes a pretty solid understanding of both Julia and automatic differentiation. | ||
| It explains how to make use of ChainRules' "rulesets" ([`frule`](@ref)s, [`rrule`](@ref)s) | ||
| to avoid having to code all your own AD primitives / custom sensitivities. | ||
|
|
||
| There are 3 main ways to access ChainRules rule sets in your AutoDiff system. | ||
|
||
|
|
||
| 1. [Operator Overloading Generation](operator_overloading.html) | ||
| - This is primarily intended for operator overloading based AD systems, which will generate overloads of primal functions for their overloaded types based on the existence of an `rrule`/`frule`. | ||
| - A source code generation based AD can also use this by overloading its transform-generating function directly, so that it does not recursively generate a transform but just returns the rule. | ||
| - This does not play nicely with Revise.jl: adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. | ||
| 2. Source code transform based on inserting branches that check if `rrule`/`frule` return `nothing` | ||
| - If the `rrule`/`frule` returns a rule result then use it; if it returns `nothing` then take the normal AD path. | ||
| - In theory type inference optimizes these branches out; in practice it may not. | ||
|
||
| - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation. | ||
| 3. Source code transform based on `rrule`/`frule` method-table | ||
| - If an applicable `rrule`/`frule` exists in the method table then use it, else generate normal AD path. | ||
| - This avoids having branches in your generated code. | ||
| - This requires maintaining your own back-edges. | ||
| - This is pretty hardcore even by the standard of source code transformations. | ||
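The second approach in the list above (branch on whether the rule returns `nothing`) can be sketched roughly as follows. This is only an illustration under stated assumptions: `toy_rrule`, `fallback_pullback`, and `value_and_pullback` are hypothetical names, the pullback returns only the input cotangent (a simplification of the real `rrule` convention), and a central finite difference stands in for the "normal AD path".

```julia
# Hypothetical stand-in for `rrule`: the fallback returns `nothing` for every input.
toy_rrule(f, args...) = nothing
# A hand-written rule for `sin`: returns (primal output, pullback).
toy_rrule(::typeof(sin), x::Float64) = (sin(x), dy -> (cos(x) * dy,))

# Hypothetical "normal AD path"; here just a central finite difference.
function fallback_pullback(f, x::Float64)
    h = 1e-6
    slope = (f(x + h) - f(x - h)) / (2h)
    return f(x), dy -> (slope * dy,)
end

# The inserted branch: use the rule if one exists, otherwise differentiate normally.
function value_and_pullback(f, x::Float64)
    res = toy_rrule(f, x)
    res === nothing && return fallback_pullback(f, x)
    return res
end

y1, pb1 = value_and_pullback(sin, 0.0)        # takes the rule branch
y2, pb2 = value_and_pullback(x -> x^2, 3.0)   # no rule, takes the fallback branch
```

Because the fallback method's return type is `Nothing`, the compiler can often resolve the `res === nothing` check statically, which is the "in theory type inference optimizes these branches out" point made above.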
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,3 @@ | ||
| ##### | ||
| ##### `frule`/`rrule` | ||
| ##### | ||
|
|
||
| """ | ||
| frule((Δf, Δx...), f, x...) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| # Infrastructure to support generating overloads from rules. | ||
| function __init__() | ||
| # Need to refresh rules when a package is loaded | ||
| push!(Base.package_callbacks, pkgid -> refresh_rules()) | ||
| end | ||
|
|
||
| # Holds all the hook functions that are invoked when a new rule is defined | ||
| const RRULE_DEFINITION_HOOKS = Function[] | ||
| const FRULE_DEFINITION_HOOKS = Function[] | ||
| _hook_list(::typeof(rrule)) = RRULE_DEFINITION_HOOKS | ||
| _hook_list(::typeof(frule)) = FRULE_DEFINITION_HOOKS | ||
|
|
||
| """ | ||
| on_new_rule(hook, frule | rrule) | ||
|
|
||
| Register a `hook` function to run when new rules are defined. | ||
| The hook receives a signature type-type as input, and generally will use `eval` to define | ||
| an overload of an AD system's overloaded type. | ||
| For example, using the signature type `Tuple{typeof(+), Real, Real}` to make | ||
| `+(::DualNumber, ::DualNumber)` call the `frule` for `+`. | ||
| A signature type tuple always has the form: | ||
| `Tuple{typeof(operation), typeof(pos_arg1), typeof(pos_arg2)...}`, where `pos_arg1` is the | ||
| first positional argument. | ||
|
|
||
| The hooks are automatically run on new rules whenever a package is loaded. | ||
| They can be manually triggered by [`refresh_rules`](@ref). | ||
| When a hook is first registered with `on_new_rule` it is run on all existing rules. | ||
| """ | ||
| function on_new_rule(hook_fun, rule_kind) | ||
| # apply the hook to the existing rules | ||
| ret = map(_rule_list(rule_kind)) do method | ||
| sig = _primal_sig(rule_kind, method) | ||
| _safe_hook_fun(hook_fun, sig) | ||
| end | ||
|
|
||
| # register hook for new rules -- so all new rules get this function applied | ||
| push!(_hook_list(rule_kind), hook_fun) | ||
| return ret | ||
| end | ||
|
|
||
| """ | ||
| clear_new_rule_hooks!(frule|rrule) | ||
|
|
||
| Clears all hooks that were registered with corresponding [`on_new_rule`](@ref). | ||
| This is useful while working interactively to define your rule generating hooks. | ||
| If you previously wrote an incorrect hook, you can use this to get rid of the old one. | ||
|
|
||
| !!! warning | ||
| This absolutely should not be used in a package, as it will break any other AD system | ||
| using the rule hooks that might happen to be loaded. | ||
| """ | ||
| clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) | ||
|
|
||
| """ | ||
| _rule_list(frule | rrule) | ||
|
|
||
| Returns a list of all the methods of the currently defined rules of the given kind. | ||
| Excluding the fallback rule that returns `nothing` for every input. | ||
| """ | ||
| function _rule_list end | ||
| # The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them | ||
| _rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) | ||
|
|
||
|
|
||
| const LAST_REFRESH_RRULE = Ref(0) | ||
| const LAST_REFRESH_FRULE = Ref(0) | ||
| last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE | ||
| last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE | ||
|
|
||
| """ | ||
| refresh_rules() | ||
| refresh_rules(frule | rrule) | ||
|
|
||
| This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. | ||
| It is *automatically* run whenever a package is loaded. | ||
| It can also be manually called to run it directly, for example if a rule was defined | ||
| in the REPL or within the same file as the AD function. | ||
| """ | ||
| function refresh_rules() | ||
| refresh_rules(frule); | ||
| refresh_rules(rrule) | ||
| end | ||
|
|
||
| function refresh_rules(rule_kind) | ||
| isempty(_hook_list(rule_kind)) && return # if no hooks, exit early, nothing to run | ||
| already_done_world_age = last_refresh(rule_kind)[] | ||
| for method in _rule_list(rule_kind) | ||
| _defined_world(method) < already_done_world_age && continue | ||
| sig = _primal_sig(rule_kind, method) | ||
| _trigger_new_rule_hooks(rule_kind, sig) | ||
| end | ||
|
|
||
| last_refresh(rule_kind)[] = _current_world() | ||
| return nothing | ||
| end | ||
|
|
||
| @static if VERSION >= v"1.2" | ||
| _current_world() = Base.get_world_counter() | ||
| _defined_world(method) = method.primary_world | ||
| else | ||
| _current_world() = ccall(:jl_get_world_counter, UInt, ()) | ||
| _defined_world(method) = method.min_world | ||
| end | ||
|
|
||
| """ | ||
| _primal_sig(frule|rrule, rule_method | rule_sig) | ||
|
|
||
| Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`. | ||
| """ | ||
| _primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig) | ||
| function _primal_sig(::typeof(frule), rule_sig::DataType) | ||
| @assert rule_sig.parameters[1] == typeof(frule) | ||
| # need to skip frule and the derivative info, so starting from the 3rd | ||
| return Tuple{rule_sig.parameters[3:end]...} | ||
| end | ||
| function _primal_sig(::typeof(rrule), rule_sig::DataType) | ||
| @assert rule_sig.parameters[1] == typeof(rrule) | ||
| # need to skip rrule so starting from the 2nd | ||
| return Tuple{rule_sig.parameters[2:end]...} | ||
| end | ||
| function _primal_sig(rule_kind, rule_sig::UnionAll) | ||
| # This looks a lot like Base.unwrap_unionall and Base.rewrap_unionall, but using those | ||
| # seems not to work | ||
| p_sig = _primal_sig(rule_kind, rule_sig.body) | ||
| return UnionAll(rule_sig.var, p_sig) | ||
| end | ||
|
|
||
|
|
||
| function _trigger_new_rule_hooks(rule_kind, sig) | ||
| for hook_fun in _hook_list(rule_kind) | ||
| _safe_hook_fun(hook_fun, sig) | ||
| end | ||
| end | ||
|
|
||
| function _safe_hook_fun(hook_fun, sig) | ||
| try | ||
| hook_fun(sig) | ||
| catch err | ||
| @error "Error triggering hook" hook_fun sig exception=err | ||
| end | ||
| end |
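To make concrete what signature a hook receives, here is a small self-contained sketch of the `frule` case of `_primal_sig`, including the `UnionAll` handling. It mirrors the logic in the diff above but uses hypothetical stand-ins (`demo_frule`, `demo_primal_sig`) rather than ChainRulesCore itself.

```julia
# Hypothetical stand-in for `frule`; the first positional argument holds the tangents.
demo_frule(dargs, f, args...) = nothing   # fallback rule
demo_frule((df, dx), ::typeof(sin), x::T) where {T<:Real} = (sin(x), cos(x) * dx)

# Drop the rule function and the tangent argument, keeping the primal call signature.
demo_primal_sig(sig::DataType) = Tuple{sig.parameters[3:end]...}
# Peel off one type variable, recurse on the body, then wrap the result again.
demo_primal_sig(sig::UnionAll) = UnionAll(sig.var, demo_primal_sig(sig.body))

# Look up the `sin` rule's method and convert its signature.
m = which(demo_frule, Tuple{Any, typeof(sin), Float64})
psig = demo_primal_sig(m.sig)   # the primal signature: `sin` applied to one `Real`
```

A hook can then test candidate call signatures against `psig` with `<:`, for example `Tuple{typeof(sin), Float64} <: psig`.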