diff --git a/src/runtime.jl b/src/runtime.jl index 09f6e3e6..9aecb195 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -1,4 +1,5 @@ using ChainRulesCore +struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}} end @Base.constprop :aggressive accum(a, b) = a + b @Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 3812eee7..2edcccf4 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -13,7 +13,7 @@ first_partial(x::CompositeBundle) = map(first_partial, getfield(x, :tup)) # TODO: Which version do we want in ChainRules? function my_frule(args::ATB{1}...) - frule(map(first_partial, args), map(primal, args)...) + frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...) end # Fast path for some hot cases @@ -118,6 +118,12 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end end +function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) + bundles = map((p,a) -> TangentBundle{1}(a, (p,)), partials, args) + result = ∂☆internal{1}()(bundles...) + primal(result), first_partial(result) +end + function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N} ∂☆p = ∂☆{minus1(N)}() ∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 4e75945a..7602927b 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -210,7 +210,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N} if N == 1 # Base case (inlined to avoid ambiguities with manually specified # higher order rules) - z = rrule(f, args...) + z = rrule(DiffractorRuleConfig(), f, args...) if z === nothing return ∂⃖recurse{1}()(f, args...) end @@ -226,6 +226,10 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end end +function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T} + ∂⃖{1}()(f, args...) +end + @Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...) return rrule(Core.apply_type, head, args...) end diff --git a/test/runtests.jl b/test/runtests.jl index 5e410556..50a9039b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,8 @@ using Diffractor -using Diffractor: var"'", ∂⃖ +using Diffractor: var"'", ∂⃖, DiffractorRuleConfig using ChainRules using ChainRulesCore -using ChainRules: ZeroTangent, NoTangent +using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad using Symbolics using LinearAlgebra @@ -198,7 +198,16 @@ end # PR #43 loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w) -x = rand(10, 10) -@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x) isa Tuple{Matrix{Float64}} +x43 = rand(10, 10) +@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} + +# PR # 45 - Calling back into AD from ChainRules +y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) +@test y45 ≈ 2.0 +@test back45(1) == (ZeroTangent(), 1.0) + +z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) +@test z45 ≈ 2.0 +@test delta45 ≈ 1.0 include("pinn.jl")