diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 195e0b9..0dec8ec 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,5 +1,10 @@ # This file is machine-generated - editing it directly is not advised +[[ANSIColoredPrinters]] +git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" +uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" +version = "0.0.1" + [[ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -11,15 +16,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "0b0aa9d61456940511416b59a0e902c57b154956" +git-tree-sha1 = "f53ca8d41e4753c41cdafa6ec5f7ce914b34be54" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.12" +version = "0.10.13" [[ChainRulesTestUtils]] deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] path = ".." uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.7.13" +version = "1.0.0-DEV" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -46,10 +51,10 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.8.5" [[Documenter]] -deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48" +deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] +git-tree-sha1 = "95265abf7d7bf06dfdb8d58525a23ea5fb0bdeee" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.3" +version = "0.27.4" [[Downloads]] deps = ["ArgTools", "LibCURL", "NetworkOptions"] @@ -57,9 +62,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[FiniteDifferences]] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"] -git-tree-sha1 = "12417e4754486a547d98d65293dc0fafdfcc0736" +git-tree-sha1 = "18761c465ef2e87d9091c0fefb61f70d532d4cc0" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.14" +version = "0.12.16" [[IOCapture]] deps = ["Logging", "Random"] @@ -167,9 +172,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "a43a7b58a6e7dc933b2fa2e0ca653ccf8bb8fd0e" +git-tree-sha1 = "1b9a0f17ee0adde9e538227de093467348992397" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.2.6" +version = "1.2.7" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] diff --git a/docs/src/api.md b/docs/src/api.md index 3886dfe..1d9790a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -4,9 +4,3 @@ Modules = [ChainRulesTestUtils] Private = false ``` - - -## Global Configuration -```@docs -ChainRulesTestUtils.enable_tangent_transform! -``` \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index 398e375..3adbd46 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -72,7 +72,7 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly ```jldoctest ex julia> test_rrule(two2three, 3.33, -7.77); Test Summary: | Pass Total -test_rrule: two2three on Float64,Float64 | 8 8 +test_rrule: two2three on Float64,Float64 | 9 9 ``` @@ -100,12 +100,12 @@ call. ```jldoctest ex julia> test_scalar(relu, 0.5); Test Summary: | Pass Total -test_scalar: relu at 0.5 | 10 10 +test_scalar: relu at 0.5 | 11 11 julia> test_scalar(relu, -0.5); Test Summary: | Pass Total -test_scalar: relu at -0.5 | 10 10 +test_scalar: relu at -0.5 | 11 11 ``` diff --git a/src/ChainRulesTestUtils.jl b/src/ChainRulesTestUtils.jl index 03b794e..7e92681 100644 --- a/src/ChainRulesTestUtils.jl +++ b/src/ChainRulesTestUtils.jl @@ -29,6 +29,4 @@ include("check_result.jl") include("rule_config.jl") include("finite_difference_calls.jl") include("testers.jl") - -include("deprecated.jl") end # module diff --git a/src/check_result.jl b/src/check_result.jl index b0541d5..b6b6e0b 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -143,7 +143,7 @@ function test_approx(actual::A, expected::E, msg=""; kwargs...) where {A,E} if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow throw(MethodError, test_approx, (actual, expected)) end - test_approx(c_actual, c_expected; kwargs...) + test_approx(c_actual, c_expected, msg; kwargs...) end end diff --git a/src/deprecated.jl b/src/deprecated.jl deleted file mode 100644 index e48c1c7..0000000 --- a/src/deprecated.jl +++ /dev/null @@ -1,93 +0,0 @@ -# TODO remove these in version 0.7 - -function Base.isapprox(a, b::Union{AbstractZero,AbstractThunk}; kwargs...) - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return isapprox(b, a; kwargs...) -end -function Base.isapprox(d_ad::AbstractThunk, d_fd; kwargs...) - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return isapprox(extern(d_ad), d_fd; kwargs...) -end -function Base.isapprox(d_ad::NoTangent, d_fd; kwargs...) - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return error("Tried to differentiate w.r.t. a `NoTangent`") -end -# Call `all` to handle the case where `ZeroTangent` is standing in for a non-scalar zero -function Base.isapprox(d_ad::ZeroTangent, d_fd; kwargs...) - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return all(isapprox.(extern(d_ad), d_fd; kwargs...)) -end - -isapprox_vec(a, b; kwargs...) = isapprox(first(to_vec(a)), first(to_vec(b)); kwargs...) -Base.isapprox(a, b::Tangent; kwargs...) = isapprox(b, a; kwargs...) -function Base.isapprox(d_ad::Tangent{<:Tuple}, d_fd::Tuple; kwargs...) - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return isapprox_vec(d_ad, d_fd; kwargs...) -end -function Base.isapprox( - d_ad::Tangent{P,<:Tuple}, d_fd::Tangent{P,<:Tuple}; kwargs... -) where {P<:Tuple} - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return isapprox_vec(d_ad, d_fd; kwargs...) -end - -function Base.isapprox( - d_ad::Tangent{P,<:NamedTuple{T}}, d_fd::Tangent{P,<:NamedTuple{T}}; kwargs... -) where {P,T} - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return isapprox_vec(d_ad, d_fd; kwargs...) -end - -# Must be for same primal -function Base.isapprox(d_ad::Tangent{P}, d_fd::Tangent{Q}; kwargs...) where {P,Q} - Base.depwarn( - "isapprox is deprecated on AbstractTangents and will be removed. " * - "Restructure testing code to use `ChainRulesTestUtils.test_approx` instead.", - :isapprox, - ) - return false -end - -############################################### - -# From when primal and tangent was passed as a tuple -@deprecate( - rrule_test(f, ȳ, inputs::Tuple{Any,Any}...; kwargs...), - test_rrule(f, ((x ⊢ dx) for (x, dx) in inputs)...; output_tangent=ȳ, kwargs...) -) - -@deprecate( - frule_test(f, inputs::Tuple{Any,Any}...; kwargs...), - test_frule(f, ((x ⊢ dx) for (x, dx) in inputs)...; kwargs...) -) - -# renamed -Base.@deprecate_binding check_equal test_approx diff --git a/src/global_config.jl b/src/global_config.jl index b8859d4..9e33033 100644 --- a/src/global_config.jl +++ b/src/global_config.jl @@ -1,28 +1,6 @@ const _fdm = central_fdm(5, 1; max_range=1e-2) const TEST_INFERRED = Ref(true) -const TRANSFORMS_TO_ALT_TANGENTS = Function[] # e.g. [x -> @thunk(x), _ -> ZeroTangent(), x -> rebasis(x)] - -""" - enable_tangent_transform!(Thunk) - -Adds a alt-tangent tranform to the list of default `tangent_transforms` for -[`test_frule`](@ref) and [`test_rrule`](@ref) to test. -This list of defaults is overwritten by the `tangent_transforms` keyword argument. - -!!! info "Transitional Feature" - ChainRulesCore v1.0 will require that all well-behaved rules work for a variety of - tangent representations. In turn, the corresponding release of ChainRulesTestUtils will - test all the different tangent representations by default. - At that stage `enable_tangent_transform!(Thunk)` will have no effect, as it will already - be enabled. - We provide this configuration as a transitional feature to help migrate your packages - one feature at a time, prior to the breaking release of ChainRulesTestUtils that will - enforce it. -""" -function enable_tangent_transform!(::Type{Thunk}) - push!(TRANSFORMS_TO_ALT_TANGENTS, x->@thunk(x)) - unique!(TRANSFORMS_TO_ALT_TANGENTS) -end +const TRANSFORMS_TO_ALT_TANGENTS = Function[x->@thunk(x)] # e.g. [_ -> ZeroTangent(), x -> rebasis(x)] "sets up TEST_INFERRED based ion enviroment variables" function init_test_inferred_setting!() diff --git a/src/testers.jl b/src/testers.jl index 33fd336..1bce4ce 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -80,12 +80,6 @@ end # Keyword Arguments - `output_tangent` tangent to test accumulation of derivatives against should be a differential for the output of `f`. Is set automatically if not provided. - - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that - transform the passed argument tangents into alternative tangents that should be tested. - Note that the alternative tangents are only tested for not erroring when passed to - frule. Testing for correctness using finite differencing can be done using a - separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness: - `test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`. - `fdm::FiniteDifferenceMethod`: the finite differencing method to use. - `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to `frule`). Used for testing gradients from AD systems. @@ -104,7 +98,6 @@ function test_frule( f, args...; output_tangent=Auto(), - tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS, fdm=_fdm, frule_f=ChainRulesCore.frule, check_inferred::Bool=true, @@ -136,41 +129,16 @@ function test_frule( Ω = call_on_copy(primals...) test_approx(Ω_ad, Ω; isapprox_kwargs...) - # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 - is_ignored = isa.(tangents, Union{Nothing,NoTangent}) - if any(tangents .== nothing) - Base.depwarn( - "test_frule(f, k ⊢ nothing) is deprecated, use " * - "test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks", - :test_frule, - ) - end - # Correctness testing via finite differencing. + is_ignored = isa.(tangents, NoTangent) dΩ_fd = _make_jvp_call(fdm, call_on_copy, Ω, primals, tangents, is_ignored) test_approx(dΩ_ad, dΩ_fd; isapprox_kwargs...) acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent _test_add!!_behaviour(acc, dΩ_ad; isapprox_kwargs...) - - # test that rules work for other tangents - _test_frule_alt_tangents( - call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc; - isapprox_kwargs... - ) end # top-level testset end -function _test_frule_alt_tangents( - call, frule_f, config, tangent_transforms, tangents, primals, acc; - isapprox_kwargs... -) - @testset "ȧrgs = $(_string_typeof(tsf.(tangents)))" for tsf in tangent_transforms - _, dΩ = call(frule_f, config, tsf.(tangents), primals...) - _test_add!!_behaviour(acc, dΩ; isapprox_kwargs...) - end -end - """ test_rrule([config::RuleConfig,] f, args...; kwargs...) @@ -185,12 +153,8 @@ end # Keyword Arguments - `output_tangent` the seed to propagate backward for testing (technically a cotangent). should be a differential for the output of `f`. Is set automatically if not provided. -- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that - transform the passed `output_tangent` into alternative tangents that should be tested. - Note that the alternative tangents are only tested for not erroring when passed to - rrule. Testing for correctness using finite differencing can be done using a - separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness: - `test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`. +- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the + output tangent to the pullback returns the same result. - `fdm::FiniteDifferenceMethod`: the finite differencing method to use. - `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`). Used for testing gradients from AD systems. @@ -209,7 +173,7 @@ function test_rrule( f, args...; output_tangent=Auto(), - tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS, + check_thunked_output_tangent=true, fdm=_fdm, rrule_f=ChainRulesCore.rrule, check_inferred::Bool=true, @@ -254,22 +218,13 @@ function test_rrule( ) # Correctness testing via finite differencing. - # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 - is_ignored = isa.(accum_cotangents, Union{Nothing, NoTangent}) - if any(accum_cotangents .== nothing) - Base.depwarn( - "test_rrule(f, k ⊢ nothing) is deprecated, use " * - "test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks", - :test_rrule, - ) - end - + is_ignored = isa.(accum_cotangents, NoTangent) fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored) for (accum_cotangent, ad_cotangent, fd_cotangent) in zip( accum_cotangents, ad_cotangents, fd_cotangents ) - if accum_cotangent isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113 + if accum_cotangent isa NoTangent # then we marked this argument as not differentiable @assert fd_cotangent === nothing # this is how `_make_j′vp_call` works ad_cotangent isa ZeroTangent && error( "The pullback in the rrule should use NoTangent()" * @@ -285,21 +240,10 @@ function test_rrule( end end - # test other tangents don't error when passed to the pullback - _test_rrule_alt_tangents(pullback, tangent_transforms, ȳ, accum_cotangents) - end # top-level testset -end - -function _test_rrule_alt_tangents( - pullback, tangent_transforms, ȳ, accum_cotangents; - isapprox_kwargs... -) - @testset "ȳ = $(_string_typeof(tsf(ȳ)))" for tsf in tangent_transforms - ad_cotangents = pullback(tsf(ȳ)) - for (accum_cotangent, ad_cotangent) in zip(accum_cotangents, ad_cotangents) - _test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...) + if check_thunked_output_tangent + test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:") end - end + end # top-level testset end """ diff --git a/test/deprecated.jl b/test/deprecated.jl deleted file mode 100644 index ef6603d..0000000 --- a/test/deprecated.jl +++ /dev/null @@ -1,248 +0,0 @@ -@testset "isapprox" begin - @testset "Tangent{Tuple}" begin - @testset "basic" begin - x_tup = (1.5, 2.5, 3.5) - x_comp = Tangent{typeof(x_tup)}(x_tup...) - @test x_comp ≈ x_tup - @test x_tup ≈ x_comp - @test x_comp ≈ x_comp - - @test_throws Exception x_comp ≈ collect(x_tup) - end - - @testset "different types" begin - # both of these are reasonable diffentials for the `Tuple{Int, Int}` primal - @test Tangent{Tuple{Int,Int}}(1.0f0, 2.0f0) ≈ Tangent{Tuple{Int,Int}}(1.0, 2.0) - - D = Diagonal(randn(5)) - @test Tangent{typeof(D)}(; diag=D.diag) ≈ Tangent{typeof(D)}(; diag=D.diag) - - # But these have different primals so should not be equal - @test !( - Tangent{Tuple{Int,Int}}(1.0, 2.0) ≈ - Tangent{Tuple{Float64,Float64}}(1.0, 2.0) - ) - end - end -end - -@testset "old testers.jl" begin - @testset "unary: identity(x)" begin - function ChainRulesCore.frule((_, ẏ), ::typeof(identity), x) - return x, ẏ - end - function ChainRulesCore.rrule(::typeof(identity), x) - function identity_pullback(ȳ) - return (NoTangent(), ȳ) - end - return x, identity_pullback - end - @testset "frule_test" begin - frule_test(identity, (randn(), randn())) - frule_test(identity, (randn(4), randn(4))) - end - @testset "rrule_test" begin - rrule_test(identity, randn(), (randn(), randn())) - rrule_test(identity, randn(4), (randn(4), randn(4))) - end - end - - @testset "test derivative conjugated in pullback" begin - ChainRulesCore.frule((_, Δx), ::typeof(sinconj), x) = (sin(x), cos(x) * Δx) - - # define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative - # in the rrule - function ChainRulesCore.rrule(::typeof(sinconj), x) - sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ) - return sin(x), sinconj_pullback - end - - rrule_test(sinconj, randn(ComplexF64), (randn(ComplexF64), randn(ComplexF64))) - test_scalar(sinconj, randn(ComplexF64)) - end - - @testset "binary: fst(x, y)" begin - fst(x, y) = x - ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx) - function ChainRulesCore.rrule(::typeof(fst), x, y) - function fst_pullback(Δx) - return (NoTangent(), Δx, ZeroTangent()) - end - return x, fst_pullback - end - @testset "frule_test" begin - frule_test(fst, (2, 4.0), (3, 5.0)) - frule_test(fst, (randn(4), randn(4)), (randn(4), randn(4))) - end - @testset "rrule_test" begin - rrule_test(fst, rand(), (2.0, 4.0), (3.0, 5.0)) - rrule_test(fst, randn(4), (randn(4), randn(4)), (randn(4), randn(4))) - end - end - - @testset "single input, multiple output" begin - simo(x) = (x, 2x) - function ChainRulesCore.rrule(simo, x) - simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b) - return simo(x), simo_pullback - end - function ChainRulesCore.frule((_, ẋ), simo, x) - y = simo(x) - return y, Tangent{typeof(y)}(ẋ, 2ẋ) - end - - @testset "frule_test" begin - frule_test(simo, (randn(), randn())) # on scalar - frule_test(simo, (randn(4), randn(4))) # on array - end - @testset "rrule_test" begin - # note: we are pulling back tuples (could use Tangents here instead) - rrule_test(simo, (randn(), rand()), (randn(), randn())) # on scalar - rrule_test(simo, (randn(4), rand(4)), (randn(4), randn(4))) # on array - end - end - - @testset "tuple input: first" begin - ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx)) - function ChainRulesCore.rrule(::typeof(first), x::Tuple) - function first_pullback(Δx) - return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x) - 1)...)) - end - return first(x), first_pullback - end - - CTuple{N} = Tangent{NTuple{N,Float64}} # shorter for testing - @testset "frule_test" begin - frule_test(first, ((2.0, 3.0), CTuple{2}(4.0, 5.0))) - frule_test(first, (Tuple(randn(4)), CTuple{4}(randn(4)...))) - end - @testset "rrule_test" begin - rrule_test(first, 2.0, ((2.0, 3.0), CTuple{2}(4.0, 5.0)); check_inferred=false) - rrule_test( - first, - randn(), - (Tuple(randn(4)), CTuple{4}(randn(4)...)); - check_inferred=false, - ) - end - end - - @testset "tuple output (backing type of Tangent =/= natural differential)" begin - tuple_out(x) = return (x, 1.0) # i.e. (x, 1.0) and not (x, x) - function ChainRulesCore.frule((_, dx), ::typeof(tuple_out), x) - Ω = tuple_out(x) - ∂Ω = Tangent{typeof(Ω)}(dx, ZeroTangent()) - return Ω, ∂Ω - end - frule_test(tuple_out, (2.0, 1)) - end - - @testset "ignoring arguments" begin - fsymtest(x, s::Symbol) = x - ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx) - function ChainRulesCore.rrule(::typeof(fsymtest), x, s) - function fsymtest_pullback(Δx) - return NoTangent(), Δx, NoTangent() - end - return x, fsymtest_pullback - end - - @testset "frule_test" begin - frule_test(fsymtest, (randn(), randn()), (:x, nothing)) - test_frule(fsymtest, 2.5, :x ⊢ nothing) - end - - @testset "rrule_test" begin - rrule_test(fsymtest, randn(), (randn(), randn()), (:x, nothing)) - test_rrule(fsymtest, 2.5, :x ⊢ nothing) - end - end - - @testset "unary with kwargs: futestkws(x; err)" begin - function ChainRulesCore.frule((_, ẋ), ::typeof(futestkws), x; err=true) - return futestkws(x; err=err), ẋ - end - function ChainRulesCore.rrule(::typeof(futestkws), x; err=true) - function futestkws_pullback(Δx) - return (NoTangent(), Δx) - end - return futestkws(x; err=err), futestkws_pullback - end - - # we defined these functions at top of file to throw errors unless we pass `err=false` - @test_throws ErrorException futestkws(randn()) - @test errors(() -> test_scalar(futestkws, randn()), "futestkws_err") - @test_throws ErrorException frule((nothing, randn()), futestkws, randn()) - @test_throws ErrorException rrule(futestkws, randn()) - - @test_throws ErrorException futestkws(randn(4)) - @test_throws ErrorException frule((nothing, randn(4)), futestkws, randn(4)) - @test_throws ErrorException rrule(futestkws, randn(4)) - - @testset "scalar_test" begin - test_scalar(futestkws, randn(); fkwargs=(; err=false)) - end - @testset "frule_test" begin - frule_test(futestkws, (randn(), randn()); fkwargs=(; err=false)) - frule_test(futestkws, (randn(4), randn(4)); fkwargs=(; err=false)) - end - @testset "rrule_test" begin - rrule_test(futestkws, randn(), (randn(), randn()); fkwargs=(; err=false)) - rrule_test(futestkws, randn(4), (randn(4), randn(4)); fkwargs=(; err=false)) - end - end - - @testset "binary with kwargs: fbtestkws(x, y; err)" begin - function ChainRulesCore.frule((_, ẋ, _), ::typeof(fbtestkws), x, y; err=true) - return fbtestkws(x, y; err=err), ẋ - end - function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err=true) - function fbtestkws_pullback(Δx) - return (NoTangent(), Δx, ZeroTangent()) - end - return fbtestkws(x, y; err=err), fbtestkws_pullback - end - - # we defined these functions at top of file to throw errors unless we pass `err=false` - @test_throws ErrorException fbtestkws(randn(), randn()) - @test_throws ErrorException frule( - (nothing, randn(), nothing), fbtestkws, randn(), randn() - ) - @test_throws ErrorException rrule(fbtestkws, randn(), randn()) - - @test_throws ErrorException fbtestkws(randn(4), randn(4)) - @test_throws ErrorException frule( - (nothing, randn(4), nothing), fbtestkws, randn(4), randn(4) - ) - @test_throws ErrorException rrule(fbtestkws, randn(4), randn(4)) - - @testset "frule_test" begin - frule_test( - fbtestkws, (randn(), randn()), (randn(), randn()); fkwargs=(; err=false) - ) - frule_test( - fbtestkws, (randn(4), randn(4)), (randn(4), randn(4)); fkwargs=(; err=false) - ) - end - @testset "rrule_test" begin - rrule_test( - fbtestkws, - randn(), - (randn(), randn()), - (randn(), randn()); - fkwargs=(; err=false), - ) - rrule_test( - fbtestkws, - randn(4), - (randn(4), randn(4)), - (randn(4), randn(4)); - fkwargs=(; err=false), - ) - end - end - - @testset "check_equal" begin - @test check_equal == test_approx - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 10f73c6..5bd9f7c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,4 @@ ChainRulesTestUtils.TEST_INFERRED[] = true include("testers.jl") include("data_generation.jl") include("rand_tangent.jl") - - include("deprecated.jl") end diff --git a/test/testers.jl b/test/testers.jl index 0d87621..2c9bfce 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -33,7 +33,7 @@ Base.iterate(f::Foo, state) = iterate(f.a, state) function ChainRulesCore.rrule(::Type{Foo}, a) foo = Foo(a) function Foo_pullback(Δfoo) - return NoTangent(), Δfoo.a + return NoTangent(), unthunk(Δfoo).a end return foo, Foo_pullback end @@ -205,7 +205,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end @testset "check not inferred in pullback" begin function ChainRulesCore.rrule(::typeof(f_noninferrable_pullback), x) function f_noninferrable_pullback_pullback(Δy) - return (NoTangent(), x > 0 ? Float64(Δy) : Float32(Δy)) + return (NoTangent(), (x > 0 ? Float64 : Float32)(unthunk(Δy))) end return x, f_noninferrable_pullback_pullback end @@ -219,7 +219,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end @testset "check not inferred in thunk" begin function ChainRulesCore.rrule(::typeof(f_noninferrable_thunk), x, y) function f_noninferrable_thunk_pullback(Δz) - ∂x = @thunk(x > 0 ? Float64(Δz) : Float32(Δz)) + ∂x = @thunk(x > 0 ? Float64(unthunk(Δz)) : Float32(unthunk(Δz))) return (NoTangent(), ∂x, Δz) end return x + y, f_noninferrable_thunk_pullback @@ -233,10 +233,13 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end @testset "check non-inferrable primal still passes if pullback inferrable" begin function ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable_pullback_only), x) - return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx)) + T = x > 0 ? Float64 : Float32 + return T(x), T(Δx) end function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x) - f_inferrable_pullback_only_pullback(Δy) = (NoTangent(), oftype(x, Δy)) + function f_inferrable_pullback_only_pullback(Δy) + return NoTangent(), oftype(x, unthunk(Δy)) + end return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback end test_frule(f_inferrable_pullback_only, 2.0; check_inferred=true) @@ -580,45 +583,6 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end end end - @testset "tangent_transforms frule" begin - others_work(x) = 2x - function ChainRulesCore.frule((Δd, Δx), ::typeof(others_work), x) - return others_work(x), 2Δx - end - - others_nowork(x) = 2x - function ChainRulesCore.frule((Δd, Δx), ::typeof(others_nowork), x) - return others_nowork(x), error("nope") - end - - test_frule(others_work, rand(); tangent_transforms=[identity, x -> @thunk(x)]) - @test errors("nope") do - test_frule(others_nowork, 2.3; tangent_transforms=[x -> @thunk(x)]) - end - end - - @testset "tangent_transforms rrule" begin - others_work(x) = 2x - function ChainRulesCore.rrule(::typeof(others_work), x) - y = others_work(x) - others_work_pullback(ȳ) = return (NoTangent(), 2ȳ) - return y, others_work_pullback - end - - others_nowork(x) = [x, x] - function ChainRulesCore.rrule(::typeof(others_nowork), x) - y = others_nowork(x) - others_nowork_pullback(ȳ) = return (NoTangent(), error("nope")) - return y, others_nowork_pullback - end - - test_rrule(others_work, 2.3; tangent_transforms=[_ -> ZeroTangent()]) - test_rrule(others_work, 2.3; tangent_transforms=[x -> @thunk(x)]) - - @test errors("nope") do - test_rrule(others_nowork, 2.3; tangent_transforms=[x -> @thunk(x)]) - end - end @testset "Tuple primal that is not equal to differential backing" begin # https://github.com/JuliaMath/SpecialFunctions.jl/issues/288 @@ -634,6 +598,35 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end test_rrule(rev_trouble, (3, 3.0) ⊢ Tangent{Tuple{Int,Float64}}(ZeroTangent(), 1.0)) end + @testset "check_thunked_output_tangent" begin + @testset "no method for thunk" begin + does_not_accept_thunk_id(x) = x + function ChainRulesCore.rrule(::typeof(does_not_accept_thunk_id), x) + does_not_accept_thunk_id_pullback(ȳ::AbstractArray) = (NoTangent() ,ȳ) + return does_not_accept_thunk_id(x), does_not_accept_thunk_id_pullback + end + + test_rrule( + does_not_accept_thunk_id, [1.0, 2.0]; check_thunked_output_tangent=false + ) + @test errors(r"MethodError.*Thunk") do + test_rrule(does_not_accept_thunk_id, [1.0, 2.0]) + end + end + + @testset "Thunk wrong" begin + bad_thunk_id(x) = x + function ChainRulesCore.rrule(::typeof(bad_thunk_id), x) + bad_thunk_id_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ) + bad_thunk_id_pullback(ȳ::AbstractThunk) = (NoTangent(), 2 * ȳ) + return bad_thunk_id(x), bad_thunk_id_pullback + end + + test_rrule(bad_thunk_id, [1.0, 2.0]; check_thunked_output_tangent=false) + @test fails(()->test_rrule(bad_thunk_id, [1.0, 2.0])) + end + end + @testset "error message about incorrectly using ZeroTangent()" begin foo(a, i) = a[i] function ChainRulesCore.rrule(::typeof(foo), a, i)