From 63ab3e13a8670827ad3d38bb20b9c9111882c3dc Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 12:10:59 +0100 Subject: [PATCH 1/8] error inside Tangent constructor if incorrect type is used --- src/tangent_types/tangent.jl | 16 ++++++++++++++++ test/tangent_types/tangent.jl | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index bb91e431e..f5f1b9e14 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -25,6 +25,22 @@ struct Tangent{P,T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain differentials) backing::T + + function Tangent{P,T}(backing) where {P,T} + function backing_error(P, G, E) + msg = "Tangent for the primal $P should be backed by a $E type, not by $G." + throw(ArgumentError(msg)) + end + + if P <: Tuple + T <: Tuple || backing_error(P, T, Tuple) + elseif P <: AbstractDict + T <: AbstractDict || backing_error(P, T, AbstractDict) + else + T <: NamedTuple || backing_error(P, T, NamedTuple) + end + return new(backing) + end end function Tangent{P}(; kwargs...) where {P} diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 26e6a2422..0ae79ae33 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -23,6 +23,26 @@ end @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} end + @testset "constructor" begin + t = (1.0, 2.0) + nt = (x = 1, y=2.0) + d = Dict(:x => 1.0, :y => 2.0) + vals = [1, 2] + + @test_throws ArgumentError Tangent{typeof(t), typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(t), typeof(d)}(d) + + @test_throws ArgumentError Tangent{typeof(d), typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(d), typeof(t)}(t) + + @test_throws ArgumentError Tangent{typeof(nt), typeof(vals)}(vals) + @test_throws ArgumentError Tangent{typeof(nt), typeof(d)}(d) + @test_throws ArgumentError Tangent{typeof(nt), typeof(t)}(t) + + @test_throws ArgumentError Tangent{Foo, typeof(d)}(d) + @test_throws ArgumentError Tangent{Foo, typeof(t)}(t) + end + @testset "==" begin @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) From ade4b5928bd39acf93d5e68a2a95c329783fa07d Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 12:11:16 +0100 Subject: [PATCH 2/8] v1.10.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3c284ca23..ad96fa04e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.10.0" +version = "1.10.1" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From e40c5011d1ee53abc594506e739cd77335f2bc02 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 12:55:51 +0100 Subject: [PATCH 3/8] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/tangent.jl | 2 +- test/tangent_types/tangent.jl | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index f5f1b9e14..9ecdf25cb 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -29,7 +29,7 @@ struct Tangent{P,T} <: AbstractTangent function Tangent{P,T}(backing) where {P,T} function backing_error(P, G, E) msg = "Tangent for the primal $P should be backed by a $E type, not by $G." - throw(ArgumentError(msg)) + return throw(ArgumentError(msg)) end if P <: Tuple diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 0ae79ae33..e9846e8f0 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -25,22 +25,22 @@ end @testset "constructor" begin t = (1.0, 2.0) - nt = (x = 1, y=2.0) + nt = (x=1, y=2.0) d = Dict(:x => 1.0, :y => 2.0) vals = [1, 2] - @test_throws ArgumentError Tangent{typeof(t), typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(t), typeof(d)}(d) + @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) - @test_throws ArgumentError Tangent{typeof(d), typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(d), typeof(t)}(t) + @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) - @test_throws ArgumentError Tangent{typeof(nt), typeof(vals)}(vals) - @test_throws ArgumentError Tangent{typeof(nt), typeof(d)}(d) - @test_throws ArgumentError Tangent{typeof(nt), typeof(t)}(t) + @test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) + @test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) + @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) - @test_throws ArgumentError Tangent{Foo, typeof(d)}(d) - @test_throws ArgumentError Tangent{Foo, typeof(t)}(t) + @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) + @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) end @testset "==" begin From 9153237e5fb227b2b741a561712f17e061822978 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 13:05:55 +0100 Subject: [PATCH 4/8] take out _backing_error function --- src/tangent_types/tangent.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index 9ecdf25cb..a58057c09 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -27,17 +27,12 @@ struct Tangent{P,T} <: AbstractTangent backing::T function Tangent{P,T}(backing) where {P,T} - function backing_error(P, G, E) - msg = "Tangent for the primal $P should be backed by a $E type, not by $G." - return throw(ArgumentError(msg)) - end - if P <: Tuple - T <: Tuple || backing_error(P, T, Tuple) + T <: Tuple || _backing_error(P, T, Tuple) elseif P <: AbstractDict - T <: AbstractDict || backing_error(P, T, AbstractDict) - else - T <: NamedTuple || backing_error(P, T, NamedTuple) + T <: AbstractDict || _backing_error(P, T, AbstractDict) + else # Any other struct (including NamedTuple) + T <: NamedTuple || _backing_error(P, T, NamedTuple) end return new(backing) end @@ -61,6 +56,11 @@ function Tangent{P}(d::Dict) where {P<:Dict} return Tangent{P,typeof(d)}(d) end +function _backing_error(P, G, E) + msg = "Tangent for the primal $P should be backed by a $E type, not by $G." + throw(ArgumentError(msg)) +end + function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end From ee8b62df14364ff3d51fc72faf37b1b7f2700985 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 13:12:22 +0100 Subject: [PATCH 5/8] Update src/tangent_types/tangent.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index a58057c09..f434fcfd6 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -58,7 +58,7 @@ end function _backing_error(P, G, E) msg = "Tangent for the primal $P should be backed by a $E type, not by $G." - throw(ArgumentError(msg)) + return throw(ArgumentError(msg)) end function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} From 51c08212b06f0fdf3a2c305e7fc8d1cd31c4bdb1 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 13:26:59 +0100 Subject: [PATCH 6/8] fix wrong test --- test/tangent_types/tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index e9846e8f0..3186aa91b 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -130,7 +130,7 @@ end @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{Foo,typeof(d)}(d) + cdict = Tangent{typeof(d),typeof(d)}(d) @test_throws MethodError reverse(Tangent{Foo}()) end From 1ce68f89cd417f78eb09f430dd58ce9d4528bb34 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 14:30:59 +0100 Subject: [PATCH 7/8] allow Any to be backed by anything --- src/tangent_types/tangent.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index f434fcfd6..c71f044e8 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -31,6 +31,7 @@ struct Tangent{P,T} <: AbstractTangent T <: Tuple || _backing_error(P, T, Tuple) elseif P <: AbstractDict T <: AbstractDict || _backing_error(P, T, AbstractDict) + elseif P == Any # can be anything else # Any other struct (including NamedTuple) T <: NamedTuple || _backing_error(P, T, NamedTuple) end From d1a9f5b8355a7ddd6fc2dd7dff92f31a2d3622de Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Oct 2021 16:19:23 +0100 Subject: [PATCH 8/8] Update src/tangent_types/tangent.jl Co-authored-by: Lyndon White --- src/tangent_types/tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index c71f044e8..ec7c64448 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -31,7 +31,7 @@ struct Tangent{P,T} <: AbstractTangent T <: Tuple || _backing_error(P, T, Tuple) elseif P <: AbstractDict T <: AbstractDict || _backing_error(P, T, AbstractDict) - elseif P == Any # can be anything + elseif P === Any # can be anything else # Any other struct (including NamedTuple) T <: NamedTuple || _backing_error(P, T, NamedTuple) end