From 8b5f8792cd47924afd0ecc8f5daa2c1bb583876f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 28 May 2021 14:51:59 +0100 Subject: [PATCH 1/6] update compat and bump version --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 976687d7..39736948 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.13" +version = "0.6.14" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "0.9.44" +ChainRulesCore = "0.9.44, 0.10.0" Compat = "3" FiniteDifferences = "0.12" julia = "1" From 6df6060935f41c02dd0d5d646eb47d3174a6340a Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 28 May 2021 14:52:12 +0100 Subject: [PATCH 2/6] replace NO_FIELDS by NoTangent() --- docs/src/index.md | 2 +- src/testers.jl | 6 +++--- test/deprecated.jl | 16 ++++++++-------- test/testers.jl | 46 +++++++++++++++++++++++----------------------- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index be6cc83f..f784ab17 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -37,7 +37,7 @@ and `rrule` function ChainRulesCore.rrule(::typeof(two2three), x1, x2) y = two2three(x1, x2) function two2three_pullback(Ȳ) - return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3]) + return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3]) end return y, two2three_pullback end diff --git a/src/testers.jl b/src/testers.jl index 10091a97..56707fc3 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -104,9 +104,9 @@ function test_frule( xs = primal.(xẋs) ẋs = tangent.(xẋs) if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...) - _test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) + _test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) end - res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) + res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) res === nothing && throw(MethodError(frule, typeof((f, xs...)))) res isa Tuple || error("The frule should return (y, ∂y), not $res.") Ω_ad, dΩ_ad = res @@ -187,7 +187,7 @@ function test_rrule( ∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.") ∂self = ∂s[1] x̄s_ad = ∂s[2:end] - @test ∂self === NO_FIELDS # No internal fields + @test ∂self === NoTangent() # No internal fields # Correctness testing via finite differencing. # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 diff --git a/test/deprecated.jl b/test/deprecated.jl index 8ae25333..bb206f0d 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -30,7 +30,7 @@ end end function ChainRulesCore.rrule(::typeof(identity), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return x, identity_pullback end @@ -50,7 +50,7 @@ end # define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative # in the rrule function ChainRulesCore.rrule(::typeof(sinconj), x) - sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ) + sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ) return sin(x), sinconj_pullback end @@ -63,7 +63,7 @@ end ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx) function ChainRulesCore.rrule(::typeof(fst), x, y) function fst_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return x, fst_pullback end @@ -80,7 +80,7 @@ end @testset "single input, multiple output" begin simo(x) = (x, 2x) function ChainRulesCore.rrule(simo, x) - simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b) + simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b) return simo(x), simo_pullback end function ChainRulesCore.frule((_, ẋ), simo, x) @@ -104,7 +104,7 @@ end ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx)) function ChainRulesCore.rrule(::typeof(first), x::Tuple) function first_pullback(Δx) - return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x)-1)...)) + return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x)-1)...)) end return first(x), first_pullback end @@ -135,7 +135,7 @@ end ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx) function ChainRulesCore.rrule(::typeof(fsymtest), x, s) function fsymtest_pullback(Δx) - return NO_FIELDS, Δx, NoTangent() + return NoTangent(), Δx, NoTangent() end return x, fsymtest_pullback end @@ -157,7 +157,7 @@ end end function ChainRulesCore.rrule(::typeof(futestkws), x; err = true) function futestkws_pullback(Δx) - return (NO_FIELDS, Δx) + return (NoTangent(), Δx) end return futestkws(x; err = err), futestkws_pullback end @@ -194,7 +194,7 @@ end end function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err = true) function fbtestkws_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return fbtestkws(x, y; err = err), fbtestkws_pullback end diff --git a/test/testers.jl b/test/testers.jl index 8fa3f948..43655b31 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -40,7 +40,7 @@ end end function ChainRulesCore.rrule(::typeof(identity), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return x, identity_pullback end @@ -64,7 +64,7 @@ end function ChainRulesCore.rrule(::typeof(identity), x::Array) function identity_pullback(ȳ) x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> (inplace_used=true; ā .+= ȳ)) - return (NO_FIELDS, x̄_ret) + return (NoTangent(), x̄_ret) end return identity(x), identity_pullback end @@ -90,7 +90,7 @@ end function my_identity_pullback(ȳ) # only the in-place part is incorrect x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ) - return (NO_FIELDS, x̄_ret) + return (NoTangent(), x̄_ret) end return my_identity(x), my_identity_pullback end @@ -103,7 +103,7 @@ end @testset "check inferred" begin ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable), x) = (x, Δx) function ChainRulesCore.rrule(::typeof(f_inferrable), x) - f_inferrable_pullback(Δy) = (NO_FIELDS, Δy) + f_inferrable_pullback(Δy) = (NoTangent(), Δy) return x, f_inferrable_pullback end @@ -120,7 +120,7 @@ end return (x, x > 0 ? Float64(Δx) : Float32(Δx)) end function ChainRulesCore.rrule(::typeof(f_noninferrable_frule), x) - f_noninferrable_frule_pullback(Δy) = (NO_FIELDS, Δy) + f_noninferrable_frule_pullback(Δy) = (NoTangent(), Δy) return x, f_noninferrable_frule_pullback end @@ -141,10 +141,10 @@ end ChainRulesCore.frule((_, Δx), ::typeof(f_noninferrable_rrule), x) = (x, Δx) function ChainRulesCore.rrule(::typeof(f_noninferrable_rrule), x) if x > 0 - f_noninferrable_rrule_pullback(Δy) = (NO_FIELDS, Δy) + f_noninferrable_rrule_pullback(Δy) = (NoTangent(), Δy) return x, f_noninferrable_rrule_pullback else - return x, _ -> (NO_FIELDS, Δy) # this is not hit by the used point + return x, _ -> (NoTangent(), Δy) # this is not hit by the used point end end @@ -163,7 +163,7 @@ end @testset "check not inferred in pullback" begin function ChainRulesCore.rrule(::typeof(f_noninferrable_pullback), x) - f_noninferrable_pullback_pullback(Δy) = (NO_FIELDS, x > 0 ? Float64(Δy) : Float32(Δy)) + f_noninferrable_pullback_pullback(Δy) = (NoTangent(), x > 0 ? Float64(Δy) : Float32(Δy)) return x, f_noninferrable_pullback_pullback end test_rrule(f_noninferrable_pullback, 2.0; check_inferred = false) @@ -177,7 +177,7 @@ end function ChainRulesCore.rrule(::typeof(f_noninferrable_thunk), x, y) function f_noninferrable_thunk_pullback(Δz) ∂x = @thunk(x > 0 ? Float64(Δz) : Float32(Δz)) - return (NO_FIELDS, ∂x, Δz) + return (NoTangent(), ∂x, Δz) end return x + y, f_noninferrable_thunk_pullback end @@ -193,7 +193,7 @@ end return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx)) end function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x) - f_inferrable_pullback_only_pullback(Δy) = (NO_FIELDS, oftype(x, Δy)) + f_inferrable_pullback_only_pullback(Δy) = (NoTangent(), oftype(x, Δy)) return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback end test_frule(f_inferrable_pullback_only, 2.0; check_inferred = true) @@ -207,7 +207,7 @@ end # define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative # in the rrule function ChainRulesCore.rrule(::typeof(sinconj), x) - sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ) + sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ) return sin(x), sinconj_pullback end @@ -220,7 +220,7 @@ end ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx) function ChainRulesCore.rrule(::typeof(fst), x, y) function fst_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return x, fst_pullback end @@ -237,7 +237,7 @@ end @testset "single input, multiple output" begin simo(x) = (x, 2x) function ChainRulesCore.rrule(simo, x) - simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b) + simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b) return simo(x), simo_pullback end function ChainRulesCore.frule((_, ẋ), simo, x) @@ -260,7 +260,7 @@ end ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx)) function ChainRulesCore.rrule(::typeof(first), x::Tuple) function first_pullback(Δx) - return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x)-1)...)) + return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x)-1)...)) end return first(x), first_pullback end @@ -291,7 +291,7 @@ end ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx) function ChainRulesCore.rrule(::typeof(fsymtest), x, s) function fsymtest_pullback(Δx) - return NO_FIELDS, Δx, NoTangent() + return NoTangent(), Δx, NoTangent() end return x, fsymtest_pullback end @@ -311,7 +311,7 @@ end end function ChainRulesCore.rrule(::typeof(futestkws), x; err = true) function futestkws_pullback(Δx) - return (NO_FIELDS, Δx) + return (NoTangent(), Δx) end return futestkws(x; err = err), futestkws_pullback end @@ -345,7 +345,7 @@ end end function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err = true) function fbtestkws_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return fbtestkws(x, y; err = err), fbtestkws_pullback end @@ -374,7 +374,7 @@ end function ChainRulesCore.rrule(::typeof(primalapprox), x) function primalapprox_pullback(Δx) - return (NO_FIELDS, Δx) + return (NoTangent(), Δx) end return x + sqrt(eps(x)), primalapprox_pullback end @@ -445,7 +445,7 @@ end Base.IteratorSize(iter), Base.IteratorEltype(iter), ) - return (NO_FIELDS, ∂iter) + return (NoTangent(), ∂iter) end return iterfun(iter), iterfun_pullback end @@ -466,7 +466,7 @@ end end function ChainRulesCore.rrule(::typeof(my_identity1), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return 2.5 * x, identity_pullback end @@ -482,7 +482,7 @@ end end function ChainRulesCore.rrule(::typeof(my_identity2), x) function identity_pullback(ȳ) - return (NO_FIELDS, 31.8 * ȳ) + return (NoTangent(), 31.8 * ȳ) end return x, identity_pullback end @@ -501,7 +501,7 @@ end rev_trouble((x,y)) = y function ChainRulesCore.rrule(::typeof(rev_trouble), (x,y)::P) where P - rev_trouble_pullback(ȳ) = (NO_FIELDS, Tangent{P}(ZeroTangent(), ȳ)) + rev_trouble_pullback(ȳ) = (NoTangent(), Tangent{P}(ZeroTangent(), ȳ)) return y, rev_trouble_pullback end test_rrule(rev_trouble, (3, 3.0) ⊢ Tangent{Tuple{Int, Float64}}(ZeroTangent(), 1.0)) @@ -513,7 +513,7 @@ end function foo_pullback(Δy) da = zeros(size(a)) da[i] = Δy - return NO_FIELDS, da, ZeroTangent() + return NoTangent(), da, ZeroTangent() end return foo(a, i), foo_pullback end From 14c6b90a2109e3a440b054b958d0cee1fc73a7fe Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 28 May 2021 14:54:00 +0100 Subject: [PATCH 3/6] compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 39736948..4891f5d9 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "0.9.44, 0.10.0" +ChainRulesCore = "0.10.0" Compat = "3" FiniteDifferences = "0.12" julia = "1" From b359b3b09e19b73ad346fdb98ad8e8a09aa8f638 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 28 May 2021 14:54:58 +0100 Subject: [PATCH 4/6] version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4891f5d9..b876e73c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.14" +version = "0.7.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 420937c88aa1649765b3025b2919314eac3720ff Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 15:43:47 +0100 Subject: [PATCH 5/6] fix docs manifest --- docs/Manifest.toml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 8a223d01..29f1b06e 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" +git-tree-sha1 = "5d64be50ea9b43a89b476be773e125cef03c7cd5" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.44" +version = "0.10.1" [[ChainRulesTestUtils]] deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] path = ".." uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.13" +version = "0.7.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -57,9 +57,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[FiniteDifferences]] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"] -git-tree-sha1 = "8662836e29702fdfdb1b90cbe4162e31b94f1e51" +git-tree-sha1 = "f8c8e287c1d68abc2719ad58fb39de9f6c0d71b1" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.7" +version = "0.12.10" [[IOCapture]] deps = ["Logging"] @@ -167,9 +167,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "a1f226ebe197578c25fcf948bfff3d0d12f2ff20" +git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.2.1" +version = "1.2.2" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] From 0be20aa042727edaf867d30c4cedbf7ed9c09dba Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Jun 2021 17:30:23 +0100 Subject: [PATCH 6/6] cleaner compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b876e73c..c74cac02 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "0.10.0" +ChainRulesCore = "0.10" Compat = "3" FiniteDifferences = "0.12" julia = "1"