diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml
index 8cf6bb1..139a131 100644
--- a/.github/workflows/IntegrationTest.yml
+++ b/.github/workflows/IntegrationTest.yml
@@ -15,13 +15,7 @@ jobs:
julia-version: [1.5]
os: [ubuntu-latest]
package:
- - {user: JuliaDiff, repo: ChainRules.jl}
- - {user: JuliaMath, repo: SpecialFunctions.jl}
- - {user: invenia, repo: BlockDiagonals.jl}
- - {user: invenia, repo: PDMatsExtras.jl}
- - {user: chrisbrahms, repo: Hankel.jl}
- - {user: SciML, repo: DiffEqBase.jl}
- - {user: dfdx, repo: Yota.jl}
+ # - {user: Invenia, repo: Nabla.jl}
steps:
- uses: actions/checkout@v2
@@ -43,7 +37,7 @@ jobs:
# force it to use this PR's version of the package
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
Pkg.update()
- Pkg.test() # resolver may fail with test time deps
+ Pkg.test() # resolver may fail with test time deps
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
# If we can't resolve that means this is incompatible by SemVer and this is fine
diff --git a/LICENSE.md b/LICENSE.md
index d8cddc9..956cad7 100644
--- a/LICENSE.md
+++ b/LICENSE.md
@@ -1,7 +1,7 @@
-The ChainRulesCore.jl package is licensed under the MIT "Expat" License:
+The ChainRulesOverloadGeneration.jl package is licensed under the MIT "Expat" License:
> Copyright (c) 2018-2019: Jarrett Revels, and other JuliaDiff Contributors:
-> https://github.com/JuliaDiff/ChainRulesCore.jl/contributors
+> https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/contributors
>
> Permission is hereby granted, free of charge, to any person obtaining a copy
> of this software and associated documentation files (the "Software"), to deal
diff --git a/Project.toml b/Project.toml
new file mode 100644
index 0000000..ed060f4
--- /dev/null
+++ b/Project.toml
@@ -0,0 +1,15 @@
+name = "ChainRulesOverloadGeneration"
+uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
+version = "0.1.0"
+
+[deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+
+[compat]
+ChainRulesCore = "0.9"
+
+[extras]
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Test"]
diff --git a/README.md b/README.md
index 79b3cdc..6d4ada9 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,16 @@
-# ChainRulesCore
+# ChainRulesOverloadGeneration
-[](https://github.com/JuliaDiff/ChainRulesCore.jl/actions?query=workflow:CI)
-[](https://codecov.io/gh/JuliaDiff/ChainRulesCore.jl)
+[](https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/actions?query=workflow:CI)
+[](https://codecov.io/gh/JuliaDiff/ChainRulesOverloadGeneration.jl)
[](https://github.com/invenia/BlueStyle)
[](https://github.com/SciML/ColPrac)
[](https://zenodo.org/badge/latestdoi/199721843)
**Docs:**
-[](https://juliadiff.org/ChainRulesCore.jl/dev)
-[](https://juliadiff.org/ChainRulesCore.jl/stable)
+[](https://juliadiff.org/ChainRulesOverloadGeneration.jl/dev)
+[](https://juliadiff.org/ChainRulesOverloadGeneration.jl/stable)
-The ChainRulesCore package provides a light-weight dependency for defining sensitivities for functions in your packages, without you needing to depend on ChainRules itself.
-
-This will allow your package to be used with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), which aims to provide a variety of common utilities that can be used by downstream automatic differentiation (AD) tools to define and execute forward-, reverse-, and mixed-mode primitives.
-
-This package is a work in progress; PRs welcome!
+The ChainRulesOverloadGeneration package provides a suite of methods for using [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) rules in operator overloading AD systems.
+It tracks what rules are defined at any point in time, and lets you trigger functions to which can use `@eval` in order to define the matching operator overloads.
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index d0f0169..3b5c092 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -11,15 +11,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-path = ".."
+git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.44"
+[[ChainRulesOverloadGeneration]]
+deps = ["ChainRulesCore"]
+path = ".."
+uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
+version = "0.1.0"
+
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "0900bc19193b8e672d9cd477e6cd92d9e7c02f99"
+git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "3.29.0"
+version = "3.30.0"
[[Dates]]
deps = ["Printf"]
diff --git a/docs/Project.toml b/docs/Project.toml
index cab507d..f3eff3a 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,5 +1,6 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
diff --git a/docs/make.jl b/docs/make.jl
index 0115025..5fa74e7 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,72 +1,32 @@
-using ChainRulesCore
+using ChainRulesOverloadGeneration
+using ChainRulesCore: ChainRulesCore
using Documenter
using DocThemeIndigo
using Markdown
-DocMeta.setdocmeta!(
- ChainRulesCore,
- :DocTestSetup,
- quote
- using Random
- Random.seed!(0) # frule doctest shows output
-
- using ChainRulesCore
- # These rules are all actually defined in ChainRules.jl, but we redefine them here to
- # avoid the dependency.
- @scalar_rule(sin(x), cos(x)) # frule and rrule doctest
- @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest
- @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest
- end
-)
-
-indigo = DocThemeIndigo.install(ChainRulesCore)
+indigo = DocThemeIndigo.install(ChainRulesOverloadGeneration)
makedocs(
- modules=[ChainRulesCore],
+ modules=[ChainRulesOverloadGeneration],
format=Documenter.HTML(
prettyurls=false,
assets=[indigo],
- mathengine=MathJax3(
- Dict(
- :tex => Dict(
- "inlineMath" => [["\$","\$"], ["\\(","\\)"]],
- "tags" => "ams",
- # TODO: remove when using physics package
- "macros" => Dict(
- "ip" => ["{\\left\\langle #1, #2 \\right\\rangle}", 2],
- "Re" => "{\\operatorname{Re}}",
- "Im" => "{\\operatorname{Im}}",
- "tr" => "{\\operatorname{tr}}",
- ),
- ),
- ),
- ),
),
- sitename="ChainRules",
- authors="Jarrett Revels and other contributors",
+ sitename="ChainRules Overload Generation",
+ authors="Lyndon White and other contributors",
pages=[
"Introduction" => "index.md",
- "FAQ" => "FAQ.md",
- "Writing Good Rules" => "writing_good_rules.md",
- "Complex Numbers" => "complex.md",
- "Deriving Array Rules" => "arrays.md",
- "Debug Mode" => "debug_mode.md",
- "Gradient Accumulation" => "gradient_accumulation.md",
- "Usage in AD" => [
- "Overview" => "autodiff/overview.md",
- "Operator Overloading" => "autodiff/operator_overloading.md",
- ],
- "Design" => [
- "Changing the Primal" => "design/changing_the_primal.md",
- "Many Differential Types" => "design/many_differentials.md",
+ "Examples of making AD systems" => [
+ "Forward Mode" => "examples/forward_mode.md",
+ "Reverse Mode" => "examples/reverse_mode.md",
],
"API" => "api.md",
- ],
+ ],
strict=true,
checkdocs=:exports,
)
deploydocs(
- repo = "github.com/JuliaDiff/ChainRulesCore.jl.git",
+ repo = "github.com/JuliaDiff/ChainRulesOverloadGeneration.jl.git",
push_preview=true,
)
diff --git a/docs/src/api.md b/docs/src/api.md
index d6db494..e3c80e0 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -1,49 +1,12 @@
# API Documentation
-## Rules
```@autodocs
-Modules = [ChainRulesCore]
-Pages = ["rules.jl"]
-Private = false
-```
-
-## Rule Definition Tools
-```@autodocs
-Modules = [ChainRulesCore]
-Pages = ["rule_definition_tools.jl"]
-Private = false
-```
-
-## Differentials
-```@autodocs
-Modules = [ChainRulesCore]
-Pages = [
- "differentials/abstract_zero.jl",
- "differentials/one.jl",
- "differentials/composite.jl",
- "differentials/thunks.jl",
- "differentials/abstract_differential.jl",
- "differentials/notimplemented.jl",
-]
-Private = false
-```
-
-## Accumulation
-```@docs
-add!!
-ChainRulesCore.is_inplaceable_destination
-```
-
-## Ruleset Loading
-```@autodocs
-Modules = [ChainRulesCore]
+Modules = [ChainRulesOverloadGeneration]
Pages = ["ruleset_loading.jl"]
Private = false
```
## Internal
```@docs
-ChainRulesCore.AbstractTangent
-ChainRulesCore.debug_mode
-ChainRulesCore.clear_new_rule_hooks!
+ChainRulesOverloadGeneration.clear_new_rule_hooks!
```
diff --git a/docs/src/assets/logo.png b/docs/src/assets/logo.png
new file mode 100644
index 0000000..e6ebdcd
Binary files /dev/null and b/docs/src/assets/logo.png differ
diff --git a/docs/src/assets/logo.svg b/docs/src/assets/logo.svg
new file mode 100644
index 0000000..1e8f0ad
--- /dev/null
+++ b/docs/src/assets/logo.svg
@@ -0,0 +1,27 @@
+
+
diff --git a/docs/src/examples/forward_mode.md b/docs/src/examples/forward_mode.md
new file mode 100644
index 0000000..31a632e
--- /dev/null
+++ b/docs/src/examples/forward_mode.md
@@ -0,0 +1,14 @@
+# ForwardDiffZero
+This is a fairly standard operator overloading-based forward mode AD system.
+It defines a `Dual` part which holds both the primal value, paired with the partial derivative.
+It doesn't handle chunked-mode, or perturbation confusion.
+The overload generation hook in this example is: `define_dual_overload`.
+
+````@eval
+using Markdown
+Markdown.parse("""
+```julia
+$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String))
+```
+""")
+````
diff --git a/docs/src/examples/reverse_mode.md b/docs/src/examples/reverse_mode.md
new file mode 100644
index 0000000..97dcabd
--- /dev/null
+++ b/docs/src/examples/reverse_mode.md
@@ -0,0 +1,16 @@
+# ReverseDiffZero
+
+This is a fairly standard operator overloading based reverse mode AD system.
+It defines a `Tracked` type which carries the primal value as well as a reference to the tape which is it using, a partially accumulated partial derivative and a `propagate` function that propagates its partial back to its input.
+A perhaps unusual thing about it is how little it carries around its creating operator's inputs.
+That information is all entirely wrapped up in the `propagate` function.
+The overload generation hook in this example is: `define_tracked_overload`.
+
+````@eval
+using Markdown
+Markdown.parse("""
+```julia
+$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String))
+```
+""")
+````
diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/index.md
similarity index 82%
rename from docs/src/autodiff/operator_overloading.md
rename to docs/src/index.md
index 56f72e1..d29896f 100644
--- a/docs/src/autodiff/operator_overloading.md
+++ b/docs/src/index.md
@@ -1,4 +1,8 @@
-# Operator Overloading
+# Operator Overloading AD with ChainRulesOverloadGeneration.jl
+
+The ChainRulesOverloadGeneration package provides a suite of methods for using [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) rules in operator overloaded based AD systems.
+It tracks what rules are defined at any point in time, and lets you trigger functions to which can use `@eval` in order to define the matching operator overloads.
+
The principal interface for using the operator overload generation method is [`on_new_rule`](@ref).
This function allows one to register a hook to be run every time a new rule is defined.
@@ -14,7 +18,7 @@ or more simply you can just use conditions for this.
For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write:
```julia
const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}}
-function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F
+function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}}) where F
@eval quote
# ...
end
@@ -53,28 +57,4 @@ When the rules are refreshed (automatically or manually), the hooks are only tri
`clear_new_rule_hooks!`(@ref) clears all registered hooks.
It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function.
-## Examples
-
-### ForwardDiffZero
-The overload generation hook in this example is: `define_dual_overload`.
-````@eval
-using Markdown
-Markdown.parse("""
-```julia
-$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String))
-```
-""")
-````
-
-### ReverseDiffZero
-The overload generation hook in this example is: `define_tracked_overload`.
-
-````@eval
-using Markdown
-Markdown.parse("""
-```julia
-$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String))
-```
-""")
-````
diff --git a/src/ChainRulesOverloadGeneration.jl b/src/ChainRulesOverloadGeneration.jl
new file mode 100644
index 0000000..8804f6b
--- /dev/null
+++ b/src/ChainRulesOverloadGeneration.jl
@@ -0,0 +1,15 @@
+module ChainRulesOverloadGeneration
+
+using ChainRulesCore
+
+export on_new_rule, refresh_rules
+
+include("ruleset_loading.jl")
+include("precompile.jl")
+
+function __init__()
+ # Need to refresh rules when a package is loaded
+ push!(Base.package_callbacks, _package_hook)
+end
+
+end # module
diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl
index aa40fae..876b8b8 100644
--- a/src/ruleset_loading.jl
+++ b/src/ruleset_loading.jl
@@ -1,9 +1,5 @@
# Infastructure to support generating overloads from rules.
_package_hook(::Base.PkgId) = refresh_rules()
-function __init__()
- # Need to refresh rules when a package is loaded
- push!(Base.package_callbacks, _package_hook)
-end
# Holds all the hook functions that are invokes when a new rule is defined
const RRULE_DEFINITION_HOOKS = Function[]
@@ -59,9 +55,10 @@ Returns a list of all the methods of the currently defined rules of the given ki
Excluding the fallback rule that returns `nothing` for every input.
"""
function _rule_list end
-# The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them
-_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__)
+_rule_list(rule_kind) = (m for m in methods(rule_kind) if !_is_fallback(rule_kind, m))
+"check if this is the fallback-frule/rrule that always returns `nothing`"
+_is_fallback(rule_kind, m::Method) = m.sig === Tuple{typeof(rule_kind), Any, Vararg{Any}}
const LAST_REFRESH_RRULE = Ref(0)
const LAST_REFRESH_FRULE = Ref(0)
diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl
index 13a1b19..3283dd3 100644
--- a/test/demos/forwarddiffzero.jl
+++ b/test/demos/forwarddiffzero.jl
@@ -1,6 +1,11 @@
"The simplest viable forward mode a AD, only supports `Float64`"
module ForwardDiffZero
using ChainRulesCore
+using ChainRulesOverloadGeneration
+# resolve conflicts while this code exists in both.
+const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
+const refresh_rules = ChainRulesOverloadGeneration.refresh_rules
+
using Test
#########################################
diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl
index 540b8ea..a410ad5 100644
--- a/test/demos/reversediffzero.jl
+++ b/test/demos/reversediffzero.jl
@@ -1,6 +1,11 @@
"The simplest viable reverse mode a AD, only supports `Float64`"
module ReverseDiffZero
using ChainRulesCore
+using ChainRulesOverloadGeneration
+# resolve conflicts while this code exists in both.
+const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
+const refresh_rules = ChainRulesOverloadGeneration.refresh_rules
+
using Test
#########################################
@@ -14,7 +19,7 @@ struct Tracked{F} <: Real
propagate::F
primal::Float64
tape::Vector{Tracked} # a reference to a shared tape
- partial::Base.RefValue{Float64} # current accumulated sensitivity
+ partial::Base.RefValue{Float64} # current accumulated sensitivity
end
"An intermediate value, a Branch in Nabla terms."
@@ -24,15 +29,15 @@ function Tracked(propagate, primal, tape)
return v
end
-"Marker for inputs (leaves) that don't need to propagate."
-struct NoPropagate end
-
"An input, a Leaf in Nabla terms. No inputs of its own to propagate to."
function Tracked(primal, tape)
# don't actually need to put these on the tape, since they don't need to propagate
return Tracked(NoPropagate(), primal, tape, Ref(zero(primal)))
end
+"Marker for inputs (leaves) that don't need to propagate."
+struct NoPropagate end
+
primal(d::Tracked) = d.primal
primal(d) = d
diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl
index e1743f0..4f3b02d 100644
--- a/test/ruleset_loading.jl
+++ b/test/ruleset_loading.jl
@@ -10,7 +10,7 @@
op = sig.parameters[1]
push!(rrule_history, op)
end
-
+
@testset "new rules hit the hooks" begin
# Now define some rules
@scalar_rule x + y (1, 1)
@@ -22,8 +22,8 @@
end
@testset "# Make sure nothing happens anymore once we clear the hooks" begin
- ChainRulesCore.clear_new_rule_hooks!(frule)
- ChainRulesCore.clear_new_rule_hooks!(rrule)
+ ChainRulesOverloadGeneration.clear_new_rule_hooks!(frule)
+ ChainRulesOverloadGeneration.clear_new_rule_hooks!(rrule)
old_frule_history = copy(frule_history)
old_rrule_history = copy(rrule_history)
@@ -34,11 +34,11 @@
@test old_rrule_history == rrule_history
@test old_frule_history == frule_history
end
-
end
+
@testset "_primal_sig" begin
- _primal_sig = ChainRulesCore._primal_sig
+ _primal_sig = ChainRulesOverloadGeneration._primal_sig
@testset "frule" begin
@test isequal( # DataType without shared type but with constraint
_primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), Int, Vector{Int}}),
@@ -69,4 +69,10 @@
)
end
end
+
+ @testset "_is_fallback" begin
+ _is_fallback = ChainRulesOverloadGeneration._is_fallback
+ @test _is_fallback(rrule, first(methods(rrule, (Nothing,))))
+ @test _is_fallback(frule, first(methods(frule, (Nothing,))))
+ end
end
diff --git a/test/runtests.jl b/test/runtests.jl
new file mode 100644
index 0000000..8ddd53f
--- /dev/null
+++ b/test/runtests.jl
@@ -0,0 +1,16 @@
+using ChainRulesCore
+using ChainRulesOverloadGeneration
+# resolve conflicts while this code exists in both.
+const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
+const refresh_rules = ChainRulesOverloadGeneration.refresh_rules
+
+using Test
+
+@testset "ChainRulesCore" begin
+ include("ruleset_loading.jl")
+
+ @testset "demos" begin
+ include("demos/forwarddiffzero.jl")
+ include("demos/reversediffzero.jl")
+ end
+end