From b548b4e65649e9da3e4e67fa6c3c3d6340587c7e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:56:37 -0400 Subject: [PATCH 1/3] accumulate NamedTuple + Tangent --- src/runtime.jl | 14 ++++++++++++++ test/runtests.jl | 10 +--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/runtime.jl b/src/runtime.jl index 9aecb195..86e0d1f3 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -13,3 +13,17 @@ end @Base.constprop :aggressive accum(a::NoTangent, b) = b @Base.constprop :aggressive accum(a, b::NoTangent) = a @Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent() + +using ChainRulesCore: Tangent, backing + +function accum(x::Tangent{T}, y::NamedTuple) where T + # @warn "gradient is both a Tangent and a NamedTuple" x y + z = accum(backing(x), y) + Tangent{T,typeof(z)}(z) +end +accum(x::NamedTuple, y::Tangent) = accum(y, x) + +function accum(x::Tangent{T}, y::Tangent) where T + z = accum(backing(x), backing(y)) + Tangent{T,typeof(z)}(z) +end diff --git a/test/runtests.jl b/test/runtests.jl index 086b4ee6..b0014bf5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -162,15 +162,7 @@ end # Make sure that there's no infinite recursion in kwarg calls g_kw(;x=1.0) = sin(x) f_kw(x) = g_kw(;x) -@test bwd(f_kw)(1.0) == bwd(sin)(1.0) broken=true -#= -MethodError: no method matching +(::Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}, ::Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}) -... - [2] elementwise_add(a::NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}, b::NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}) - @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/tangent.jl:287 - [3] +(a::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}}, b::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}}) - @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:130 -=# +@test bwd(f_kw)(1.0) == bwd(sin)(1.0) function f_crit_edge(a, b, c, x) # A function with two critical edges. This used to trigger an issue where From d50f6e8c0a0533da1a80c40ccc01a17a0509a42f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 31 Aug 2022 18:45:49 -0400 Subject: [PATCH 2/3] fixup --- src/extra_rules.jl | 3 +++ src/runtime.jl | 18 +++++++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/extra_rules.jl b/src/extra_rules.jl index 56ce436e..cd70a6df 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -266,3 +266,6 @@ end function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, add!!, val) val, Δ->(NoTangent(), NoTangent(), Δ) end + +Base.real(z::ZeroTangent) = z # TODO should be in CRC +Base.real(z::NoTangent) = z diff --git a/src/runtime.jl b/src/runtime.jl index 86e0d1f3..e3c1edec 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -5,25 +5,25 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}} @Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b) @Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple) fnames = union(fieldnames(x), fieldnames(y)) + isempty(fnames) && return :((;)) # code below makes () instead gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent()) grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent()) Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...) end @Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...) -@Base.constprop :aggressive accum(a::NoTangent, b) = b -@Base.constprop :aggressive accum(a, b::NoTangent) = a -@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent() +@Base.constprop :aggressive accum(a::AbstractZero, b) = b +@Base.constprop :aggressive accum(a, b::AbstractZero) = a +@Base.constprop :aggressive accum(a::AbstractZero, b::AbstractZero) = NoTangent() using ChainRulesCore: Tangent, backing function accum(x::Tangent{T}, y::NamedTuple) where T # @warn "gradient is both a Tangent and a NamedTuple" x y - z = accum(backing(x), y) - Tangent{T,typeof(z)}(z) + _tangent(T, accum(backing(x), y)) end accum(x::NamedTuple, y::Tangent) = accum(y, x) +# This solves an ambiguity, but also avoids Tangent{ZeroTangent}() which + does not: +accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing(y))) -function accum(x::Tangent{T}, y::Tangent) where T - z = accum(backing(x), backing(y)) - Tangent{T,typeof(z)}(z) -end +_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z) +_tangent(::Type, ::NamedTuple{()}) = NoTangent() From c9f2a90b7673f008cfb7723e7517fa79654ebc6e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 4 Sep 2022 14:49:15 -0400 Subject: [PATCH 3/3] don't test on 1.8 --- .github/workflows/CI.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7781617f..ea181ba6 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -17,7 +17,9 @@ jobs: matrix: version: - '1.7' # Lowest claimed support in Project.toml - - '1' # Latest Release + # - '1' # Latest Release # Testing on 1.8 gives this message: + # ┌ Warning: ir verification broken. Either use 1.9 or 1.7 + # └ @ Diffractor ~/work/Diffractor.jl/Diffractor.jl/src/stage1/recurse.jl:889 - 'nightly' os: - ubuntu-latest