Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
684661a
move using MulAddMacro to right place
oxinabox Jul 7, 2020
90a87ea
Add frule and rrule decorator macros
oxinabox Jul 7, 2020
0ad8235
Update src/rules.jl
oxinabox Jul 10, 2020
8de4ca9
Initial sketch of capturing the AST and feeding it to new rule hooks
oxinabox Jul 17, 2020
4476c4e
sort out API for overload generation
oxinabox Jul 21, 2020
8a1fb9c
add ForwardDiffZero as an API integration test
oxinabox Jul 21, 2020
83d7e8b
Revert "Add frule and rrule decorator macros"
oxinabox Jul 24, 2020
ffeb861
use refresh_rules either manually or autoamtically on pkg load / file…
oxinabox Jul 24, 2020
78fd24f
directly interpolate function type in
oxinabox Jul 24, 2020
0a741db
replace missed opname with op [fixme]
oxinabox Jul 25, 2020
a0129d0
don't handle multi-input
oxinabox Aug 17, 2020
d4efa8e
Add ReverseDiffZero demo
oxinabox Aug 18, 2020
3375791
remove excess new lines
oxinabox Aug 18, 2020
463166d
Update test/demos/reversediffzero.jl
oxinabox Aug 18, 2020
536dac3
Update test/demos/forwarddiffzero.jl
oxinabox Aug 18, 2020
4de2b6b
Update test/demos/reversediffzero.jl
oxinabox Aug 18, 2020
f4ed7c9
Apply suggestions from code review
oxinabox Aug 18, 2020
1872cb1
more comments
oxinabox Aug 18, 2020
37dda58
Apply suggestions from code review
oxinabox Aug 19, 2020
6a04dac
use paritial for all deriviative parts in demos
oxinabox Aug 19, 2020
dd083af
remove debug stuff
oxinabox Aug 19, 2020
562fe72
tweak comments etc
oxinabox Aug 19, 2020
0422e9c
start writing docs for using overload generation
oxinabox Aug 19, 2020
65e8c3d
working on docs
oxinabox Aug 20, 2020
6e58754
finish first pass at docs
oxinabox Aug 20, 2020
faa2087
more docs
oxinabox Aug 20, 2020
a909e2d
handle Unionall Signatures
oxinabox Aug 20, 2020
fb8cdf6
Stop refreshing rules on include_callback
oxinabox Aug 20, 2020
b8d1581
tweaks
oxinabox Aug 20, 2020
27a8592
remove type_constraint_equal
oxinabox Aug 21, 2020
36b6410
Update test/demos/reversediffzero.jl
oxinabox Aug 21, 2020
e87b845
Style and comment fixes
oxinabox Aug 21, 2020
1d91366
Don't export clear_new_rule_hooks!
oxinabox Aug 21, 2020
4ec4981
Update docs/make.jl
oxinabox Aug 21, 2020
baf4431
move comemnt
oxinabox Aug 21, 2020
ada4822
fix dotpoints in docs
oxinabox Aug 21, 2020
cfee703
fix clear rule hooks in tests
oxinabox Aug 21, 2020
3320fba
bump version
oxinabox Aug 21, 2020
fdca95c
Apply suggestions from code review
oxinabox Aug 22, 2020
d97535f
Update docs/src/autodiff/operator_overloading.md
oxinabox Aug 24, 2020
7d32509
More docs on generation
oxinabox Aug 24, 2020
ecd2bb6
test clear hooks
oxinabox Aug 24, 2020
eccb894
wrap up code review
oxinabox Aug 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.5"
version = "0.9.6"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Expand Down
10 changes: 5 additions & 5 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ version = "0.8.2"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "1c593d1efa27437ed9dd365d1143c594b563e138"
git-tree-sha1 = "fb1ff838470573adc15c71ba79f8d31328f035da"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.25.1"
version = "0.25.2"

[[DocumenterTools]]
deps = ["Base64", "DocStringExtensions", "Documenter", "FileWatching", "LibGit2", "Sass"]
Expand Down Expand Up @@ -78,9 +78,9 @@ version = "0.2.2"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "10134f2ee0b1978ae7752c41306e131a684e1f06"
git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.7"
version = "1.0.10"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand All @@ -91,7 +91,7 @@ deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"

[compat]
Documenter = "0.25"
Expand Down
5 changes: 5 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ChainRulesCore
using Documenter
using DocumenterTools: Themes
using Markdown

DocMeta.setdocmeta!(
ChainRulesCore,
Expand Down Expand Up @@ -36,6 +37,10 @@ makedocs(
"Complex Numbers" => "complex.md",
"Deriving Array Rules" => "arrays.md",
"Debug Mode" => "debug_mode.md",
"Usage in AD" => [
"Overview" => "autodiff/overview.md",
"Operator Overloading" => "autodiff/operator_overloading.md"
],
"Design" => [
"Many Differential Types" => "design/many_differentials.md",
],
Expand Down
8 changes: 8 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@ Pages = [
Private = false
```

## Ruleset Loading
```@autodocs
Modules = [ChainRulesCore]
Pages = ["ruleset_loading.jl"]
Private = false
```

## Internal
```@docs
ChainRulesCore.AbstractDifferential
ChainRulesCore.debug_mode
ChainRulesCore.clear_new_rule_hooks!
```
80 changes: 80 additions & 0 deletions docs/src/autodiff/operator_overloading.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Operator Overloading

The principal interface for using the operator overload generation method is [`on_new_rule`](@ref).
This function allows one to register a hook to be run every time a new rule is defined.
The hook receives a signature type-type as input, and generally will use `eval` to define
an overload of an AD system's overloaded type.
For example, using the signature type `Tuple{typeof(+), Real, Real}` to make
`+(::DualNumber, ::DualNumber)` call the `frule` for `+`.
A signature type tuple always has the form:
`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}, ...}`, where `pos_arg1` is the
first positional argument.
One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`;
or more simply you can just use conditions for this.
For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write:
```julia
const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}}
function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F
@eval quote
# ...
end
end
define_overload(::Any) = nothing # don't do anything for any other signature

on_new_rule(frule, define_overload)
```

or you might write:
```julia
const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64})
function define_overload(sig) where F
sig = Base.unwrap_unionall(sig) # not really handling most UnionAll,
opT, argTs = Iterators.peel(sig.parameters)
all(any(acceptT<: argT for acceptT in ACCEPT_TYPES) for argT in argTs) || return
@eval quote
# ...
end
end

on_new_rule(frule, define_overload)
```

The generation of overloaded code is the responsibility of the AD implementor.
Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this.
Its generally fairly simple, though can become complex if you need to handle complicated type-constraints.
Examples are shown below.

The hook is automatically triggered whenever a package is loaded.
It can also be triggers manually using `refresh_rules`(@ref).
This is useful for example if new rules are define in the REPL, or if a package defining rules is modified.
(Revise.jl will not automatically trigger).
When the rules are refreshed (automatically or manually), the hooks are only triggered on new/modified rules; not ones that have already had the hooks triggered on.

`clear_new_rule_hooks!`(@ref) clears all registered hooks.
It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function.

## Examples

### ForwardDiffZero
The overload generation hook in this example is: `define_dual_overload`.

````@eval
using Markdown
Markdown.parse("""
```julia
$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String))
```
""")
````

### ReverseDiffZero
The overload generation hook in this example is: `define_tracked_overload`.

````@eval
using Markdown
Markdown.parse("""
```julia
$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String))
```
""")
````
22 changes: 22 additions & 0 deletions docs/src/autodiff/overview.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Using ChainRules in your AD system

This section is for authors of AD systems.
It assumes a pretty solid understanding of both Julia and automatic differentiation.
It explains how to make use of ChainRule's "rulesets" ([`frule`](@ref)s, [`rrule`](@ref)s,)
to avoid having to code all your own AD primitives / custom sensitives.

There are 3 main ways to access ChainRules rule sets in your AutoDiff system.

1. [Operation Overloading Generation](operator_overloading.html)
- This is primarily intended for operator overloading based AD systems which will generate overloads for primal functions based for their overloaded types based on the existence of an `rrule`/`frule`.
- A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule.
- This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all.
2. Source code tranform based on inserting branches that check of `rrule`/`frule` return `nothing`
- If the `rrule`/`frule` returns a rule result then use it, if it returns `nothing` then do normal AD path.
- In theory type inference optimizes these branchs out; in practice it may not.
- This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation.
3. Source code transform based on `rrule`/`frule` method-table
- If an applicable `rrule`/`frule` exists in the method table then use it, else generate normal AD path.
- This avoids having branches in your generated code.
- This requires maintaining your own back-edges.
- This is pretty hardcore even by the standard of source code tranformations.
10 changes: 7 additions & 3 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using MuladdMacro: @muladd

export frule, rrule
export @scalar_rule, @thunk
export canonicalize, extern, unthunk
export on_new_rule, refresh_rules # generation tools
export frule, rrule # core function
export @scalar_rule, @thunk # definition helper macros
export canonicalize, extern, unthunk # differential operations
# differentials
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk
export NO_FIELDS

Expand All @@ -20,5 +23,6 @@ include("differential_arithmetic.jl")

include("rules.jl")
include("rule_definition_tools.jl")
include("ruleset_loading.jl")

end # module
2 changes: 0 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# These are some macros (and supporting functions) to make it easier to define rules.
using MuladdMacro: @muladd

"""
@scalar_rule(f(x₁, x₂, ...),
@setup(statement₁, statement₂, ...),
Expand Down
4 changes: 0 additions & 4 deletions src/rules.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
#####
##### `frule`/`rrule`
#####

"""
frule((Δf, Δx...), f, x...)

Expand Down
141 changes: 141 additions & 0 deletions src/ruleset_loading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Infastructure to support generating overloads from rules.
function __init__()
# Need to refresh rules when a package is loaded
push!(Base.package_callbacks, pkgid -> refresh_rules())
end

# Holds all the hook functions that are invokes when a new rule is defined
const RRULE_DEFINITION_HOOKS = Function[]
const FRULE_DEFINITION_HOOKS = Function[]
_hook_list(::typeof(rrule)) = RRULE_DEFINITION_HOOKS
_hook_list(::typeof(frule)) = FRULE_DEFINITION_HOOKS

"""
on_new_rule(hook, frule | rrule)

Register a `hook` function to run when new rules are defined.
The hook receives a signature type-type as input, and generally will use `eval` to define
an overload of an AD system's overloaded type
For example, using the signature type `Tuple{typeof(+), Real, Real}` to make
`+(::DualNumber, ::DualNumber)` call the `frule` for `+`.
A signature type tuple always has the form:
`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}...}`, where `pos_arg1` is the
first positional argument.

The hooks are automatically run on new rules whenever a package is loaded.
They can be manually triggered by [`refresh_rules`](@ref).
When a hook is first registered with `on_new_rule` it is run on all existing rules.
"""
function on_new_rule(hook_fun, rule_kind)
# apply the hook to the existing rules
ret = map(_rule_list(rule_kind)) do method
sig = _primal_sig(rule_kind, method)
_safe_hook_fun(hook_fun, sig)
end

# register hook for new rules -- so all new rules get this function applied
push!(_hook_list(rule_kind), hook_fun)
return ret
end

"""
clear_new_rule_hooks!(frule|rrule)

Clears all hooks that were registered with corresponding [`on_new_rule`](@ref).
This is useful for while working interactively to define your rule generating hooks.
If you previously wrong an incorrect hook, you can use this to get rid of the old one.

!!! warning
This absolutely should not be used in a package, as it will break any other AD system
using the rule hooks that might happen to be loaded.
"""
clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind))

"""
_rule_list(frule | rrule)

Returns a list of all the methods of the currently defined rules of the given kind.
Excluding the fallback rule that returns `nothing` for every input.
"""
function _rule_list end
# The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them
_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__)


const LAST_REFRESH_RRULE = Ref(0)
const LAST_REFRESH_FRULE = Ref(0)
last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE
last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE

"""
refresh_rules()
refresh_rules(frule | rrule)

This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules.
It is *automatically* run when ever a package is loaded.
It can also be manually called to run it directly, for example if a rule was defined
in the REPL or within the same file as the AD function.
"""
function refresh_rules()
refresh_rules(frule);
refresh_rules(rrule)
end

function refresh_rules(rule_kind)
isempty(_rule_list(rule_kind)) && return # if no hooks, exit early, nothing to run
already_done_world_age = last_refresh(rule_kind)[]
for method in _rule_list(rule_kind)
_defined_world(method) < already_done_world_age && continue
sig = _primal_sig(rule_kind, method)
_trigger_new_rule_hooks(rule_kind, sig)
end

last_refresh(rule_kind)[] = _current_world()
return nothing
end

@static if VERSION >= v"1.2"
_current_world() = Base.get_world_counter()
_defined_world(method) = method.primary_world
else
_current_world() = ccall(:jl_get_world_counter, UInt, ())
_defined_world(method) = method.min_world
end

"""
_primal_sig(frule|rule, rule_method | rule_sig)

Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`.
"""
_primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig)
function _primal_sig(::typeof(frule), rule_sig::DataType)
@assert rule_sig.parameters[1] == typeof(frule)
# need to skip frule and the deriviative info, so starting from the 3rd
return Tuple{rule_sig.parameters[3:end]...}
end
function _primal_sig(::typeof(rrule), rule_sig::DataType)
@assert rule_sig.parameters[1] == typeof(rrule)
# need to skip rrule so starting from the 2rd
return Tuple{rule_sig.parameters[2:end]...}
end
function _primal_sig(rule_kind, rule_sig::UnionAll)
# This looks a lot like Base.unwrap_unionall and Base.rewrap_unionall, but using those
# seems not to work
p_sig = _primal_sig(rule_kind, rule_sig.body)
return UnionAll(rule_sig.var, p_sig)
end


function _trigger_new_rule_hooks(rule_kind, sig)
for hook_fun in _hook_list(rule_kind)
_safe_hook_fun(hook_fun, sig)
end
end

function _safe_hook_fun(hook_fun, sig)
try
hook_fun(sig)
catch err
@error "Error triggering hook" hook_fun sig exception=err
end
end
Loading