From 276eead90f1e16b21a7eec2c55d523e43da4ca5f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 9 Sep 2021 17:01:48 -0400 Subject: [PATCH 1/5] add RuleConfig --- src/runtime.jl | 1 + src/stage1/forward.jl | 4 +++- src/stage1/generated.jl | 6 +++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/runtime.jl b/src/runtime.jl index 09f6e3e6..9aecb195 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -1,4 +1,5 @@ using ChainRulesCore +struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}} end @Base.constprop :aggressive accum(a, b) = a + b @Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 3812eee7..5585d0e9 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -13,7 +13,7 @@ first_partial(x::CompositeBundle) = map(first_partial, getfield(x, :tup)) # TODO: Which version do we want in ChainRules? function my_frule(args::ATB{1}...) - frule(map(first_partial, args), map(primal, args)...) + frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...) end # Fast path for some hot cases @@ -118,6 +118,8 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end end +ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, args...) = ∂☆internal{1}()(args...) + function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N} ∂☆p = ∂☆{minus1(N)}() ∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 4e75945a..7602927b 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -210,7 +210,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N} if N == 1 # Base case (inlined to avoid ambiguities with manually specified # higher order rules) - z = rrule(f, args...) + z = rrule(DiffractorRuleConfig(), f, args...) if z === nothing return ∂⃖recurse{1}()(f, args...) end @@ -226,6 +226,10 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end end +function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T} + ∂⃖{1}()(f, args...) +end + @Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...) return rrule(Core.apply_type, head, args...) end From c8cc0532db675f1577a68fd7da44e5d3c1e937ca Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 9 Sep 2021 22:14:36 -0400 Subject: [PATCH 2/5] add tests, frule is broken --- test/runtests.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5e410556..91e1413f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,8 @@ using Diffractor -using Diffractor: var"'", ∂⃖ +using Diffractor: var"'", ∂⃖, DiffractorRuleConfig using ChainRules using ChainRulesCore -using ChainRules: ZeroTangent, NoTangent +using ChainRules: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad using Symbolics using LinearAlgebra @@ -201,4 +201,13 @@ loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w) x = rand(10, 10) @test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x) isa Tuple{Matrix{Float64}} +# PR # 45 - Calling back into AD from ChainRules +y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2.0) +@test y45 ≈ 2.0 +@test back45(1) == (ZeroTangent(), 1.0) + +# z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0) +# @test z45 ≈ 2.0 +# @test delta45 ≈ 1.0 + include("pinn.jl") From fc3bf404f3072cfefe0dec8928db9e887fb1a237 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 10 Sep 2021 00:06:52 -0400 Subject: [PATCH 3/5] closer? --- src/stage1/forward.jl | 7 ++++++- test/runtests.jl | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 5585d0e9..1a8adfb0 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -118,7 +118,12 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end end -ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, args...) = ∂☆internal{1}()(args...) +function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) + tangents = map(partials, args) do p, a + TangentBundle{1}(a, (p,)) + end + ∂☆internal{1}()(tangents...) +end function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N} ∂☆p = ∂☆{minus1(N)}() diff --git a/test/runtests.jl b/test/runtests.jl index 91e1413f..8d443265 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -206,8 +206,8 @@ y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2.0) @test y45 ≈ 2.0 @test back45(1) == (ZeroTangent(), 1.0) -# z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0) -# @test z45 ≈ 2.0 -# @test delta45 ≈ 1.0 +z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0) +@test z45 ≈ 2.0 +@test delta45 ≈ 1.0 include("pinn.jl") From f16cb00a8ac0491d9ba15a119a6d325c889ab3c4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 10 Sep 2021 10:48:38 -0400 Subject: [PATCH 4/5] fixup --- src/stage1/forward.jl | 7 +++---- test/runtests.jl | 8 ++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index 1a8adfb0..2edcccf4 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -119,10 +119,9 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...) end function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...) - tangents = map(partials, args) do p, a - TangentBundle{1}(a, (p,)) - end - ∂☆internal{1}()(tangents...) + bundles = map((p,a) -> TangentBundle{1}(a, (p,)), partials, args) + result = ∂☆internal{1}()(bundles...) + primal(result), first_partial(result) end function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N} diff --git a/test/runtests.jl b/test/runtests.jl index 8d443265..404e999e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -198,15 +198,15 @@ end # PR #43 loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w) -x = rand(10, 10) -@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x) isa Tuple{Matrix{Float64}} +x43 = rand(10, 10) +@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} # PR # 45 - Calling back into AD from ChainRules -y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2.0) +y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) @test y45 ≈ 2.0 @test back45(1) == (ZeroTangent(), 1.0) -z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0) +z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test z45 ≈ 2.0 @test delta45 ≈ 1.0 From 48460e719a2b50e0365f692c47c7692ddb1994ab Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 10 Sep 2021 19:50:43 -0400 Subject: [PATCH 5/5] tweak import --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 404e999e..50a9039b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using Diffractor using Diffractor: var"'", ∂⃖, DiffractorRuleConfig using ChainRules using ChainRulesCore -using ChainRules: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad +using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad using Symbolics using LinearAlgebra