From 54365f8ca0a060cfc926a5c88bb005a75d238f08 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 01:59:10 +0200 Subject: [PATCH 1/8] ruleconfig support and Zygote tests --- Project.toml | 2 ++ src/AbstractDifferentiation.jl | 4 ++++ src/ruleconfig.jl | 21 +++++++++++++++++++++ test/ruleconfig.jl | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+) create mode 100644 src/ruleconfig.jl create mode 100644 test/ruleconfig.jl diff --git a/Project.toml b/Project.toml index d5d742b..91446c2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["Mohamed Tarek and contributors"] version = "0.4.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] +ChainRulesCore = "1" Compat = "3" ExprTools = "0.1" ForwardDiff = "0.10" diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index deebc2a..dc53ce5 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -643,11 +643,15 @@ end @inline asarray(x) = [x] @inline asarray(x::AbstractArray) = x +include("ruleconfig.jl") function __init__() @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl") @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl") @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl") @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl") + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) + end end end diff --git a/src/ruleconfig.jl b/src/ruleconfig.jl new file mode 100644 index 0000000..d9f9a20 --- /dev/null +++ b/src/ruleconfig.jl @@ -0,0 +1,21 @@ +using ChainRulesCore: RuleConfig, rrule_via_ad + +""" + ReverseRuleConfigBackend + +AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package. +""" +struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode + ruleconfig::RC +end + +AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) + return (vs) -> begin + _, back = rrule_via_ad(ab.ruleconfig, f, xs...) + if vs isa Tuple && length(vs) === 1 + return Base.tail(back(vs[1])) + else + return Base.tail(back(vs)) + end + end +end diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl new file mode 100644 index 0000000..61bcdbf --- /dev/null +++ b/test/ruleconfig.jl @@ -0,0 +1,33 @@ +using AbstractDifferentiation +using Test +using Zygote, Yota + +@testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin + backends = [@inferred(AD.ZygoteBackend())] + @testset for backend in backends + @testset "Derivative" begin + test_derivatives(backend) + end + @testset "Gradient" begin + test_gradients(backend) + end + @testset "Jacobian" begin + test_jacobians(backend) + end + @testset "jvp" begin + test_jvp(backend) + end + @testset "j′vp" begin + test_j′vp(backend) + end + @testset "Lazy Derivative" begin + test_lazy_derivatives(backend) + end + @testset "Lazy Gradient" begin + test_lazy_gradients(backend) + end + @testset "Lazy Jacobian" begin + test_lazy_jacobians(backend) + end + end +end From 0a0d8154c4ed1a659af55f50c34a896c0b96c141 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 02:22:49 +0200 Subject: [PATCH 2/8] actually run the tests :) --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 0348523..d79bafc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,5 @@ using Test include("reversediff.jl") include("finitedifferences.jl") include("tracker.jl") + include("ruleconfig.jl") end From ce1c0c3b69c0c4c7e2fbe2168c3f0b376034fb4b Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 02:44:41 +0200 Subject: [PATCH 3/8] avoid loading Yota --- test/ruleconfig.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl index 61bcdbf..412d89c 100644 --- a/test/ruleconfig.jl +++ b/test/ruleconfig.jl @@ -1,6 +1,6 @@ using AbstractDifferentiation using Test -using Zygote, Yota +using Zygote @testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin backends = [@inferred(AD.ZygoteBackend())] From ebbd71b52e39815bf83f14ba721889d9668a3452 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Tue, 8 Feb 2022 02:58:39 +0200 Subject: [PATCH 4/8] lower bound Zygote compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 91446c2..b9e306f 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ ExprTools = "0.1" ForwardDiff = "0.10" Requires = "0.5, 1" ReverseDiff = "1" +Zygote = "0.6" julia = "1" [extras] From 2fef1a12d63451a82335828e9dbb2660ca3e2967 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Wed, 9 Feb 2022 01:16:12 +0200 Subject: [PATCH 5/8] move imports --- src/AbstractDifferentiation.jl | 5 ++++- src/ruleconfig.jl | 2 -- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index dc53ce5..12aadb9 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -1,6 +1,7 @@ module AbstractDifferentiation using LinearAlgebra, ExprTools, Requires, Compat +using ChainRulesCore: RuleConfig, rrule_via_ad export AD @@ -650,7 +651,9 @@ function __init__() @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl") @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl") @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) + @static if VERSION >= v"1.6" + ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) + end end end diff --git a/src/ruleconfig.jl b/src/ruleconfig.jl index d9f9a20..8174d7a 100644 --- a/src/ruleconfig.jl +++ b/src/ruleconfig.jl @@ -1,5 +1,3 @@ -using ChainRulesCore: RuleConfig, rrule_via_ad - """ ReverseRuleConfigBackend From 1dd97627912f6369f996bd7ebe2848fe63dc7361 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Wed, 9 Feb 2022 01:16:20 +0200 Subject: [PATCH 6/8] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b9e306f..249cba6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AbstractDifferentiation" uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" authors = ["Mohamed Tarek and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 07e8e337552708fa887475347df60e6df1de05c7 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Wed, 9 Feb 2022 01:16:43 +0200 Subject: [PATCH 7/8] avoid testing ruleconfig in 1.0 --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index d79bafc..b435dee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,5 +8,7 @@ using Test include("reversediff.jl") include("finitedifferences.jl") include("tracker.jl") - include("ruleconfig.jl") + @static if VERSION >= v"1.6" + include("ruleconfig.jl") + end end From 6207218503a228e63a2d5c324ba80f4a770cad0a Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Wed, 9 Feb 2022 01:19:28 +0200 Subject: [PATCH 8/8] remove Zygote compat --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 249cba6..c544dc8 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,6 @@ ExprTools = "0.1" ForwardDiff = "0.10" Requires = "0.5, 1" ReverseDiff = "1" -Zygote = "0.6" julia = "1" [extras]