diff --git a/Project.toml b/Project.toml index 3c284ca23..ad96fa04e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.10.0" +version = "1.10.1" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index bb91e431e..ec7c64448 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -25,6 +25,18 @@ struct Tangent{P,T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T + + function Tangent{P,T}(backing) where {P,T} + if P <: Tuple + T <: Tuple || _backing_error(P, T, Tuple) + elseif P <: AbstractDict + T <: AbstractDict || _backing_error(P, T, AbstractDict) + elseif P === Any # can be anything + else # Any other struct (including NamedTuple) + T <: NamedTuple || _backing_error(P, T, NamedTuple) + end + return new(backing) + end end function Tangent{P}(; kwargs...) where {P} @@ -45,6 +57,11 @@ function Tangent{P}(d::Dict) where {P<:Dict} return Tangent{P,typeof(d)}(d) end +function _backing_error(P, G, E) + msg = "Tangent for the primal $P should be backed by a $E type, not by $G." + return throw(ArgumentError(msg)) +end + function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 26e6a2422..3186aa91b 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -23,6 +23,26 @@ end @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} end + @testset "constructor" begin + t = (1.0, 2.0) + nt = (x=1, y=2.0) + d = Dict(:x => 1.0, :y => 2.0) + vals = [1, 2] + + @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) + + @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) + + @test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) + @test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) + @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) + + @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) + @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) + end + @testset "==" begin @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) @@ -110,7 +130,7 @@ end @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo,typeof(d)}(d) + cdict = Tangent{typeof(d),typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end