diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index 8cf6bb1..139a131 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -15,13 +15,7 @@ jobs: julia-version: [1.5] os: [ubuntu-latest] package: - - {user: JuliaDiff, repo: ChainRules.jl} - - {user: JuliaMath, repo: SpecialFunctions.jl} - - {user: invenia, repo: BlockDiagonals.jl} - - {user: invenia, repo: PDMatsExtras.jl} - - {user: chrisbrahms, repo: Hankel.jl} - - {user: SciML, repo: DiffEqBase.jl} - - {user: dfdx, repo: Yota.jl} + # - {user: Invenia, repo: Nabla.jl} steps: - uses: actions/checkout@v2 @@ -43,7 +37,7 @@ jobs: # force it to use this PR's version of the package Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps Pkg.update() - Pkg.test() # resolver may fail with test time deps + Pkg.test() # resolver may fail with test time deps catch err err isa Pkg.Resolve.ResolverError || rethrow() # If we can't resolve that means this is incompatible by SemVer and this is fine diff --git a/LICENSE.md b/LICENSE.md index d8cddc9..956cad7 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,7 +1,7 @@ -The ChainRulesCore.jl package is licensed under the MIT "Expat" License: +The ChainRulesOverloadGeneration.jl package is licensed under the MIT "Expat" License: > Copyright (c) 2018-2019: Jarrett Revels, and other JuliaDiff Contributors: -> https://github.com/JuliaDiff/ChainRulesCore.jl/contributors +> https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/contributors > > Permission is hereby granted, free of charge, to any person obtaining a copy > of this software and associated documentation files (the "Software"), to deal diff --git a/Project.toml b/Project.toml new file mode 100644 index 0000000..ed060f4 --- /dev/null +++ b/Project.toml @@ -0,0 +1,15 @@ +name = "ChainRulesOverloadGeneration" +uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" +version = "0.1.0" + +[deps] +ChainRulesCore = 
"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[compat] +ChainRulesCore = "0.9" + +[extras] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test"] diff --git a/README.md b/README.md index 79b3cdc..6d4ada9 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,16 @@ -# ChainRulesCore +# ChainRulesOverloadGeneration -[![Build Status](https://github.com/JuliaDiff/ChainRulesCore.jl/workflows/CI/badge.svg)](https://github.com/JuliaDiff/ChainRulesCore.jl/actions?query=workflow:CI) -[![Coverage](https://codecov.io/gh/JuliaDiff/ChainRulesCore.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaDiff/ChainRulesCore.jl) +[![Build Status](https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/workflows/CI/badge.svg)](https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/actions?query=workflow:CI) +[![Coverage](https://codecov.io/gh/JuliaDiff/ChainRulesOverloadGeneration.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaDiff/ChainRulesOverloadGeneration.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![DOI](https://zenodo.org/badge/199721843.svg)](https://zenodo.org/badge/latestdoi/199721843) **Docs:** -[![](https://img.shields.io/badge/docs-master-blue.svg)](https://juliadiff.org/ChainRulesCore.jl/dev) -[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/ChainRulesCore.jl/stable) +[![](https://img.shields.io/badge/docs-master-blue.svg)](https://juliadiff.org/ChainRulesOverloadGeneration.jl/dev) +[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/ChainRulesOverloadGeneration.jl/stable) -The ChainRulesCore package provides a light-weight dependency for defining sensitivities for functions in your packages, without you 
needing to depend on ChainRules itself. - -This will allow your package to be used with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), which aims to provide a variety of common utilities that can be used by downstream automatic differentiation (AD) tools to define and execute forward-, reverse-, and mixed-mode primitives. - -This package is a work in progress; PRs welcome! +The ChainRulesOverloadGeneration package provides a suite of methods for using [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) rules in operator overloading AD systems. +It tracks what rules are defined at any point in time, and lets you trigger functions which can use `@eval` in order to define the matching operator overloads. diff --git a/docs/Manifest.toml b/docs/Manifest.toml index d0f0169..3b5c092 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -11,15 +11,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -path = ".." +git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" version = "0.9.44" +[[ChainRulesOverloadGeneration]] +deps = ["ChainRulesCore"] +path = ".."
+uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" +version = "0.1.0" + [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "0900bc19193b8e672d9cd477e6cd92d9e7c02f99" +git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.29.0" +version = "3.30.0" [[Dates]] deps = ["Printf"] diff --git a/docs/Project.toml b/docs/Project.toml index cab507d..f3eff3a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,6 @@ [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/make.jl b/docs/make.jl index 0115025..5fa74e7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,72 +1,32 @@ -using ChainRulesCore +using ChainRulesOverloadGeneration +using ChainRulesCore: ChainRulesCore using Documenter using DocThemeIndigo using Markdown -DocMeta.setdocmeta!( - ChainRulesCore, - :DocTestSetup, - quote - using Random - Random.seed!(0) # frule doctest shows output - - using ChainRulesCore - # These rules are all actually defined in ChainRules.jl, but we redefine them here to - # avoid the dependency. 
- @scalar_rule(sin(x), cos(x)) # frule and rrule doctest - @scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest - @scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest - end -) - -indigo = DocThemeIndigo.install(ChainRulesCore) +indigo = DocThemeIndigo.install(ChainRulesOverloadGeneration) makedocs( - modules=[ChainRulesCore], + modules=[ChainRulesOverloadGeneration], format=Documenter.HTML( prettyurls=false, assets=[indigo], - mathengine=MathJax3( - Dict( - :tex => Dict( - "inlineMath" => [["\$","\$"], ["\\(","\\)"]], - "tags" => "ams", - # TODO: remove when using physics package - "macros" => Dict( - "ip" => ["{\\left\\langle #1, #2 \\right\\rangle}", 2], - "Re" => "{\\operatorname{Re}}", - "Im" => "{\\operatorname{Im}}", - "tr" => "{\\operatorname{tr}}", - ), - ), - ), - ), ), - sitename="ChainRules", - authors="Jarrett Revels and other contributors", + sitename="ChainRules Overload Generation", + authors="Lyndon White and other contributors", pages=[ "Introduction" => "index.md", - "FAQ" => "FAQ.md", - "Writing Good Rules" => "writing_good_rules.md", - "Complex Numbers" => "complex.md", - "Deriving Array Rules" => "arrays.md", - "Debug Mode" => "debug_mode.md", - "Gradient Accumulation" => "gradient_accumulation.md", - "Usage in AD" => [ - "Overview" => "autodiff/overview.md", - "Operator Overloading" => "autodiff/operator_overloading.md", - ], - "Design" => [ - "Changing the Primal" => "design/changing_the_primal.md", - "Many Differential Types" => "design/many_differentials.md", + "Examples of making AD systems" => [ + "Forward Mode" => "examples/forward_mode.md", + "Reverse Mode" => "examples/reverse_mode.md", ], "API" => "api.md", - ], + ], strict=true, checkdocs=:exports, ) deploydocs( - repo = "github.com/JuliaDiff/ChainRulesCore.jl.git", + repo = "github.com/JuliaDiff/ChainRulesOverloadGeneration.jl.git", push_preview=true, ) diff --git a/docs/src/api.md b/docs/src/api.md index d6db494..e3c80e0 100644 --- 
a/docs/src/api.md +++ b/docs/src/api.md @@ -1,49 +1,12 @@ # API Documentation -## Rules ```@autodocs -Modules = [ChainRulesCore] -Pages = ["rules.jl"] -Private = false -``` - -## Rule Definition Tools -```@autodocs -Modules = [ChainRulesCore] -Pages = ["rule_definition_tools.jl"] -Private = false -``` - -## Differentials -```@autodocs -Modules = [ChainRulesCore] -Pages = [ - "differentials/abstract_zero.jl", - "differentials/one.jl", - "differentials/composite.jl", - "differentials/thunks.jl", - "differentials/abstract_differential.jl", - "differentials/notimplemented.jl", -] -Private = false -``` - -## Accumulation -```@docs -add!! -ChainRulesCore.is_inplaceable_destination -``` - -## Ruleset Loading -```@autodocs -Modules = [ChainRulesCore] +Modules = [ChainRulesOverloadGeneration] Pages = ["ruleset_loading.jl"] Private = false ``` ## Internal ```@docs -ChainRulesCore.AbstractTangent -ChainRulesCore.debug_mode -ChainRulesCore.clear_new_rule_hooks! +ChainRulesOverloadGeneration.clear_new_rule_hooks! ``` diff --git a/docs/src/assets/logo.png b/docs/src/assets/logo.png new file mode 100644 index 0000000..e6ebdcd Binary files /dev/null and b/docs/src/assets/logo.png differ diff --git a/docs/src/assets/logo.svg b/docs/src/assets/logo.svg new file mode 100644 index 0000000..1e8f0ad --- /dev/null +++ b/docs/src/assets/logo.svg @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/src/examples/forward_mode.md b/docs/src/examples/forward_mode.md new file mode 100644 index 0000000..31a632e --- /dev/null +++ b/docs/src/examples/forward_mode.md @@ -0,0 +1,14 @@ +# ForwardDiffZero +This is a fairly standard operator overloading-based forward mode AD system. +It defines a `Dual` part which holds both the primal value, paired with the partial derivative. +It doesn't handle chunked-mode, or perturbation confusion. +The overload generation hook in this example is: `define_dual_overload`. 
+ +````@eval +using Markdown +Markdown.parse(""" +```julia +$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) +``` +""") +```` diff --git a/docs/src/examples/reverse_mode.md b/docs/src/examples/reverse_mode.md new file mode 100644 index 0000000..97dcabd --- /dev/null +++ b/docs/src/examples/reverse_mode.md @@ -0,0 +1,16 @@ +# ReverseDiffZero + +This is a fairly standard operator overloading based reverse mode AD system. +It defines a `Tracked` type which carries the primal value as well as a reference to the tape which it is using, a partially accumulated partial derivative and a `propagate` function that propagates its partial back to its input. +A perhaps unusual thing about it is how little it carries around its creating operator's inputs. +That information is all entirely wrapped up in the `propagate` function. +The overload generation hook in this example is: `define_tracked_overload`. + +````@eval +using Markdown +Markdown.parse(""" +```julia +$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) +``` +""") +```` diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/index.md similarity index 82% rename from docs/src/autodiff/operator_overloading.md rename to docs/src/index.md index 56f72e1..d29896f 100644 --- a/docs/src/autodiff/operator_overloading.md +++ b/docs/src/index.md @@ -1,4 +1,8 @@ -# Operator Overloading +# Operator Overloading AD with ChainRulesOverloadGeneration.jl + +The ChainRulesOverloadGeneration package provides a suite of methods for using [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) rules in operator overloading based AD systems. +It tracks what rules are defined at any point in time, and lets you trigger functions which can use `@eval` in order to define the matching operator overloads. + The principal interface for using the operator overload generation method is [`on_new_rule`](@ref).
This function allows one to register a hook to be run every time a new rule is defined. @@ -14,7 +18,7 @@ or more simply you can just use conditions for this. For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write: ```julia const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}} -function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F +function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}}) where F @eval quote # ... end @@ -53,28 +57,4 @@ When the rules are refreshed (automatically or manually), the hooks are only tri `clear_new_rule_hooks!`(@ref) clears all registered hooks. It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function. -## Examples - -### ForwardDiffZero -The overload generation hook in this example is: `define_dual_overload`. -````@eval -using Markdown -Markdown.parse(""" -```julia -$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) -``` -""") -```` - -### ReverseDiffZero -The overload generation hook in this example is: `define_tracked_overload`. 
- -````@eval -using Markdown -Markdown.parse(""" -```julia -$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) -``` -""") -```` diff --git a/src/ChainRulesOverloadGeneration.jl b/src/ChainRulesOverloadGeneration.jl new file mode 100644 index 0000000..8804f6b --- /dev/null +++ b/src/ChainRulesOverloadGeneration.jl @@ -0,0 +1,15 @@ +module ChainRulesOverloadGeneration + +using ChainRulesCore + +export on_new_rule, refresh_rules + +include("ruleset_loading.jl") +include("precompile.jl") + +function __init__() + # Need to refresh rules when a package is loaded + push!(Base.package_callbacks, _package_hook) +end + +end # module diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index aa40fae..876b8b8 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -1,9 +1,5 @@ # Infastructure to support generating overloads from rules. _package_hook(::Base.PkgId) = refresh_rules() -function __init__() - # Need to refresh rules when a package is loaded - push!(Base.package_callbacks, _package_hook) -end # Holds all the hook functions that are invokes when a new rule is defined const RRULE_DEFINITION_HOOKS = Function[] @@ -59,9 +55,10 @@ Returns a list of all the methods of the currently defined rules of the given ki Excluding the fallback rule that returns `nothing` for every input. 
""" function _rule_list end -# The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them -_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) +_rule_list(rule_kind) = (m for m in methods(rule_kind) if !_is_fallback(rule_kind, m)) +"check if this is the fallback-frule/rrule that always returns `nothing`" +_is_fallback(rule_kind, m::Method) = m.sig === Tuple{typeof(rule_kind), Any, Vararg{Any}} const LAST_REFRESH_RRULE = Ref(0) const LAST_REFRESH_FRULE = Ref(0) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 13a1b19..3283dd3 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -1,6 +1,11 @@ "The simplest viable forward mode a AD, only supports `Float64`" module ForwardDiffZero using ChainRulesCore +using ChainRulesOverloadGeneration +# resolve conflicts while this code exists in both. +const on_new_rule = ChainRulesOverloadGeneration.on_new_rule +const refresh_rules = ChainRulesOverloadGeneration.refresh_rules + using Test ######################################### diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index 540b8ea..a410ad5 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -1,6 +1,11 @@ "The simplest viable reverse mode a AD, only supports `Float64`" module ReverseDiffZero using ChainRulesCore +using ChainRulesOverloadGeneration +# resolve conflicts while this code exists in both. +const on_new_rule = ChainRulesOverloadGeneration.on_new_rule +const refresh_rules = ChainRulesOverloadGeneration.refresh_rules + using Test ######################################### @@ -14,7 +19,7 @@ struct Tracked{F} <: Real propagate::F primal::Float64 tape::Vector{Tracked} # a reference to a shared tape - partial::Base.RefValue{Float64} # current accumulated sensitivity + partial::Base.RefValue{Float64} # current accumulated sensitivity end "An intermediate value, a Branch in Nabla terms." 
@@ -24,15 +29,15 @@ function Tracked(propagate, primal, tape) return v end -"Marker for inputs (leaves) that don't need to propagate." -struct NoPropagate end - "An input, a Leaf in Nabla terms. No inputs of its own to propagate to." function Tracked(primal, tape) # don't actually need to put these on the tape, since they don't need to propagate return Tracked(NoPropagate(), primal, tape, Ref(zero(primal))) end +"Marker for inputs (leaves) that don't need to propagate." +struct NoPropagate end + primal(d::Tracked) = d.primal primal(d) = d diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index e1743f0..4f3b02d 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -10,7 +10,7 @@ op = sig.parameters[1] push!(rrule_history, op) end - + @testset "new rules hit the hooks" begin # Now define some rules @scalar_rule x + y (1, 1) @@ -22,8 +22,8 @@ end @testset "# Make sure nothing happens anymore once we clear the hooks" begin - ChainRulesCore.clear_new_rule_hooks!(frule) - ChainRulesCore.clear_new_rule_hooks!(rrule) + ChainRulesOverloadGeneration.clear_new_rule_hooks!(frule) + ChainRulesOverloadGeneration.clear_new_rule_hooks!(rrule) old_frule_history = copy(frule_history) old_rrule_history = copy(rrule_history) @@ -34,11 +34,11 @@ @test old_rrule_history == rrule_history @test old_frule_history == frule_history end - end + @testset "_primal_sig" begin - _primal_sig = ChainRulesCore._primal_sig + _primal_sig = ChainRulesOverloadGeneration._primal_sig @testset "frule" begin @test isequal( # DataType without shared type but with constraint _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), Int, Vector{Int}}), @@ -69,4 +69,10 @@ ) end end + + @testset "_is_fallback" begin + _is_fallback = ChainRulesOverloadGeneration._is_fallback + @test _is_fallback(rrule, first(methods(rrule, (Nothing,)))) + @test _is_fallback(frule, first(methods(frule, (Nothing,)))) + end end diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 
0000000..8ddd53f --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,16 @@ +using ChainRulesCore +using ChainRulesOverloadGeneration +# resolve conflicts while this code exists in both. +const on_new_rule = ChainRulesOverloadGeneration.on_new_rule +const refresh_rules = ChainRulesOverloadGeneration.refresh_rules + +using Test + +@testset "ChainRulesOverloadGeneration" begin # top-level suite, labelled with this package's name + include("ruleset_loading.jl") + + @testset "demos" begin # the demo AD systems are included and run as part of the suite + include("demos/forwarddiffzero.jl") + include("demos/reversediffzero.jl") + end +end