From c89318449adc7094c7ca6bf70bd52c66ae140834 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Jul 2022 20:08:40 -0400 Subject: [PATCH 1/3] handle Base.tail & friens --- Project.toml | 2 +- src/tangent_types/abstract_zero.jl | 4 ++++ src/tangent_types/tangent.jl | 7 +++++++ src/tangent_types/thunks.jl | 4 ++++ test/tangent_types/abstract_zero.jl | 4 ++++ test/tangent_types/tangent.jl | 7 +++++++ test/tangent_types/thunks.jl | 7 +++++++ 7 files changed, 34 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bc1ec7b45..3b97c30ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.2" +version = "1.15.3" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index e52d84819..986fc9854 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -16,6 +16,10 @@ Base.iszero(::AbstractZero) = true Base.iterate(x::AbstractZero) = (x, nothing) Base.iterate(::AbstractZero, ::Any) = nothing +Base.first(x::AbstractZero) = x +Base.tail(x::AbstractZero) = x +Base.last(x::AbstractZero) = x + Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index f187cb3f2..1e8914dba 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -96,6 +96,12 @@ function Base.show(io::IO, tangent::Tangent{P}) where {P} end end +Base.first(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = first(backing(canonicalize(tangent))) +Base.last(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = last(backing(canonicalize(tangent))) + +Base.tail(tangent::Tangent{P}) where {P<:Tuple} = Tangent{_tailtype(P)}(Base.tail(backing(tangent))...) +@generated _tailtype(::Type{P}) where {P<:Tuple} = Tuple{P.parameters[2:end]...} + function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) @@ -127,6 +133,7 @@ end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) + Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) function Base.reverse(tangent::Tangent) rev_backing = reverse(backing(tangent)) diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 3735307b0..8baa006e8 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -24,6 +24,10 @@ end return element, (underlying_object, new_state) end +Base.first(x::AbstractThunk) = first(unthunk(x)) +Base.last(x::AbstractThunk) = last(unthunk(x)) +Base.tail(x::AbstractThunk) = Base.tail(unthunk(x)) + Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b) Base.:(-)(a::AbstractThunk) = -unthunk(a) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 43e433b6c..e3d8642e4 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -86,6 +86,10 @@ @test z[1:3] === z @test z[1, 2] === z @test getindex(z) === z + + @test first(z) === z + @test last(z) === z + @test Base.tail(z) === z end @testset "NoTangent" begin diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 176dd4985..e37a45a68 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -78,6 +78,13 @@ end @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + + tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4) + @test @inferred(first(tang3)) === tang3[1] === 1.0 + @test @inferred(last(tang3)) isa Thunk + @test unthunk(last(tang3)) == [7.0] + @test Tuple(@inferred Base.tail(tang3))[1] === NoTangent() + @test Tuple(Base.tail(tang3))[end] isa Thunk NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index 4b3ab4eb4..68c2cc53a 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -16,6 +16,13 @@ @test nothing === iterate(@thunk ()) == iterate(()) end + + @testset "first, last, tail" begin + @test first(@thunk (1,2,3) .+ 4) === 5 + @test last(@thunk (1,2,3) .+ 4) === 7 + @test Base.tail(@thunk (1,2,3) .+ 4) === (6, 7) + @test Base.tail(@thunk NoTangent() * 5) === NoTangent() + end @testset "show" begin rep = repr(Thunk(rand)) From 652d88c24fac48079bfdd286e29ae65b0ae960c4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 19 Jul 2022 07:42:06 -0400 Subject: [PATCH 2/3] empty cases --- src/tangent_types/tangent.jl | 4 +++- test/tangent_types/tangent.jl | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index 1e8914dba..6d83596f5 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -99,8 +99,10 @@ end Base.first(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = first(backing(canonicalize(tangent))) Base.last(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = last(backing(canonicalize(tangent))) -Base.tail(tangent::Tangent{P}) where {P<:Tuple} = Tangent{_tailtype(P)}(Base.tail(backing(tangent))...) +Base.tail(t::Tangent{P}) where {P<:Tuple} = Tangent{_tailtype(P)}(Base.tail(backing(canonicalize(t)))...) @generated _tailtype(::Type{P}) where {P<:Tuple} = Tuple{P.parameters[2:end]...} +Base.tail(t::Tangent{<:Tuple{Any}}) = NoTangent() +Base.tail(t::Tangent{<:Tuple{}}) = NoTangent() function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} back = backing(canonicalize(tangent)) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index e37a45a68..96946288b 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -72,12 +72,15 @@ end end @test Tangent{Foo}(; x=2.5).x == 2.5 - @test keys(Tangent{Tuple{Float64}}(2.0)) == Base.OneTo(1) + tang1 = Tangent{Tuple{Float64}}(2.0) + @test keys(tang1) == Base.OneTo(1) @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test NoTangent() === @inferred Base.tail(tang1) + @test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}()) tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4) @test @inferred(first(tang3)) === tang3[1] === 1.0 From 3db6428c79d7307ea92b5c9356187daffaab3f41 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 19 Jul 2022 07:42:27 -0400 Subject: [PATCH 3/3] tail on NamedTuples too --- src/tangent_types/tangent.jl | 5 +++++ test/tangent_types/tangent.jl | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index 6d83596f5..d13e75c76 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -104,6 +104,11 @@ Base.tail(t::Tangent{P}) where {P<:Tuple} = Tangent{_tailtype(P)}(Base.tail(back Base.tail(t::Tangent{<:Tuple{Any}}) = NoTangent() Base.tail(t::Tangent{<:Tuple{}}) = NoTangent() +Base.tail(t::Tangent{P}) where {P<:NamedTuple} = Tangent{_tailtype(P)}(; Base.tail(backing(canonicalize(t)))...) +_tailtype(::Type{NamedTuple{S,P}}) where {S,P} = NamedTuple{Base.tail(S), _tailtype(P)} +Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{Any}}}) = NoTangent() +Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{}}}) = NoTangent() + function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} back = backing(canonicalize(tangent)) return unthunk(getfield(back, idx)) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 96946288b..004dd71cf 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -99,6 +99,14 @@ end @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk + @test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0 + @test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent + + ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2))) + @test ntang1 isa Tangent{<:NamedTuple{(:b,)}} + @test NoTangent() === @inferred Base.tail(ntang1) # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 if VERSION >= v"1.8-"