From bea1b810a608e2984b55b26bfcd2587e68ea3793 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 16:04:13 +0800 Subject: [PATCH 1/3] add frules for getfield --- src/rulesets/Base/indexing.jl | 9 ++++++--- test/rulesets/Base/indexing.jl | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 1334cc925..37ed8ca48 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -1,6 +1,10 @@ # Int rather than Int64/Integer is intentional -function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int) - return x.i, ẋ.i +function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}) + return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) +end + +function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds) + return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) end "for a given tuple type, returns a Val{N} where N is the length of the tuple" @@ -21,7 +25,6 @@ function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Nu dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T)) return (NoTangent(), Tangent{T}(dx...), NoTangent()) end - return x[i], getindex_back_2 end # Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector), diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index d3c7ecfb4..c21bb8425 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,3 +1,17 @@ +@testset "getfield" begin + struct Foo + x::Float64 + y::Float64 + end + test_frule(getfield, Foo(1.5, 2.5), :x, check_inferred=false) + + test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false) + test_frule(getfield, (; a=1.5, b=2.5), 2) + + test_frule(getfield, (1.5, 2.5), 2) + test_frule(getfield, (1.5, 2.5), 2, true) +end + @testset "getindex" begin @testset "getindex(::Tuple, ...)" begin x = (1.2, 3.4, 5.6) From 84cd7be44fb0493aeb18168bf6aca59fc5885d65 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 22:14:11 +0800 Subject: [PATCH 2/3] move struct to top level --- test/rulesets/Base/indexing.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index c21bb8425..a677df3b9 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,9 +1,11 @@ +struct FooTwoField + x::Float64 + y::Float64 +end + + @testset "getfield" begin - struct Foo - x::Float64 - y::Float64 - end - test_frule(getfield, Foo(1.5, 2.5), :x, check_inferred=false) + test_frule(getfield, FooTwoField(1.5, 2.5), :x, check_inferred=false) test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false) test_frule(getfield, (; a=1.5, b=2.5), 2) From a76e0ae4a536e1e35d3024d50a940cf7c8a78df2 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Sep 2023 14:49:14 +0800 Subject: [PATCH 3/3] undo mistakenly deleted line --- src/rulesets/Base/indexing.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 37ed8ca48..2f5e6cf79 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -25,6 +25,7 @@ function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Nu dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T)) return (NoTangent(), Tangent{T}(dx...), NoTangent()) end + return x[i], getindex_back_2 end # Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector),