diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 1334cc925..2f5e6cf79 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -1,6 +1,10 @@ # Int rather than Int64/Integer is intentional -function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int) - return x.i, ẋ.i +function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}) + return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) +end + +function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds) + return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) end "for a given tuple type, returns a Val{N} where N is the length of the tuple" diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index d3c7ecfb4..a677df3b9 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,3 +1,19 @@ +struct FooTwoField + x::Float64 + y::Float64 +end + + +@testset "getfield" begin + test_frule(getfield, FooTwoField(1.5, 2.5), :x, check_inferred=false) + + test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false) + test_frule(getfield, (; a=1.5, b=2.5), 2) + + test_frule(getfield, (1.5, 2.5), 2) + test_frule(getfield, (1.5, 2.5), 2, true) +end + @testset "getindex" begin @testset "getindex(::Tuple, ...)" begin x = (1.2, 3.4, 5.6)