diff --git a/Project.toml b/Project.toml index d5d742b..c544dc8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,17 @@ name = "AbstractDifferentiation" uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" authors = ["Mohamed Tarek and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] +ChainRulesCore = "1" Compat = "3" ExprTools = "0.1" ForwardDiff = "0.10" diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index deebc2a..12aadb9 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -1,6 +1,7 @@ module AbstractDifferentiation using LinearAlgebra, ExprTools, Requires, Compat +using ChainRulesCore: RuleConfig, rrule_via_ad export AD @@ -643,11 +644,17 @@ end @inline asarray(x) = [x] @inline asarray(x::AbstractArray) = x +include("ruleconfig.jl") function __init__() @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl") @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl") @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl") @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl") + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + @static if VERSION >= v"1.6" + ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) + end + end end end diff --git a/src/ruleconfig.jl b/src/ruleconfig.jl new file mode 100644 index 0000000..8174d7a --- /dev/null +++ b/src/ruleconfig.jl @@ -0,0 +1,19 @@ +""" + ReverseRuleConfigBackend + +AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package. +""" +struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode + ruleconfig::RC +end + +AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) + return (vs) -> begin + _, back = rrule_via_ad(ab.ruleconfig, f, xs...) + if vs isa Tuple && length(vs) === 1 + return Base.tail(back(vs[1])) + else + return Base.tail(back(vs)) + end + end +end diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl new file mode 100644 index 0000000..412d89c --- /dev/null +++ b/test/ruleconfig.jl @@ -0,0 +1,33 @@ +using AbstractDifferentiation +using Test +using Zygote + +@testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin + backends = [@inferred(AD.ZygoteBackend())] + @testset for backend in backends + @testset "Derivative" begin + test_derivatives(backend) + end + @testset "Gradient" begin + test_gradients(backend) + end + @testset "Jacobian" begin + test_jacobians(backend) + end + @testset "jvp" begin + test_jvp(backend) + end + @testset "j′vp" begin + test_j′vp(backend) + end + @testset "Lazy Derivative" begin + test_lazy_derivatives(backend) + end + @testset "Lazy Gradient" begin + test_lazy_gradients(backend) + end + @testset "Lazy Jacobian" begin + test_lazy_jacobians(backend) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 0348523..b435dee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,7 @@ using Test include("reversediff.jl") include("finitedifferences.jl") include("tracker.jl") + @static if VERSION >= v"1.6" + include("ruleconfig.jl") + end end