From 439d036956b5ab4da96d041d0556a3d0a495c980 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 28 May 2021 21:36:43 +0100 Subject: [PATCH 1/2] Avoid nesting testsets in test_rule --- src/ChainRulesTestUtils.jl | 3 ++ src/check_result.jl | 89 +++++++++++++++++++++++--------------- src/output_control.jl | 59 +++++++++++++++++++++++++ src/testers.jl | 15 +++---- test/testers.jl | 4 +- 5 files changed, 125 insertions(+), 45 deletions(-) create mode 100644 src/output_control.jl diff --git a/src/ChainRulesTestUtils.jl b/src/ChainRulesTestUtils.jl index ea32449c..9979e864 100644 --- a/src/ChainRulesTestUtils.jl +++ b/src/ChainRulesTestUtils.jl @@ -17,9 +17,12 @@ export TestIterator export check_equal, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix export ⊢ + include("generate_tangent.jl") include("data_generation.jl") include("iterator.jl") + +include("output_control.jl") include("check_result.jl") include("finite_difference_calls.jl") diff --git a/src/check_result.jl b/src/check_result.jl index 3008dd95..2a9a54f5 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -4,50 +4,56 @@ # Note that this must work well both on Differential types and Primal types """ - check_equal(actual, expected; kwargs...) + check_equal(actual, expected, [msg]; kwargs...) `@test`'s that `actual ≈ expected`, but breaks up data such that human readable results are shown on failures. Understands things like `unthunk`ing `ChainRuleCore.Thunk`s, etc. + +If provided `msg` is printed on a failure. Often additional items are appended to `msg` to +give bread-crumbs into nested structures. + All keyword arguments are passed to `isapprox`. """ function check_equal( actual::Union{AbstractArray{<:Number},Number}, - expected::Union{AbstractArray{<:Number},Number}; + expected::Union{AbstractArray{<:Number},Number}, + msg="", + ; kwargs..., ) - @test isapprox(actual, expected; kwargs...) + @test_msg msg isapprox(actual, expected; kwargs...) end for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk)) - @eval function check_equal(actual::$T1, expected::$T2; kwargs...) - return check_equal(unthunk(actual), unthunk(expected); kwargs...) + @eval function check_equal(actual::$T1, expected::$T2, msg=""; kwargs...) + return check_equal(unthunk(actual), unthunk(expected), msg; kwargs...) end end -check_equal(::ZeroTangent, x; kwargs...) = check_equal(zero(x), x; kwargs...) -check_equal(x, ::ZeroTangent; kwargs...) = check_equal(x, zero(x); kwargs...) -check_equal(x::ZeroTangent, y::ZeroTangent; kwargs...) = @test true +check_equal(::ZeroTangent, x, msg=""; kwargs...) = check_equal(zero(x), x, msg; kwargs...) +check_equal(x, ::ZeroTangent, msg=""; kwargs...) = check_equal(x, zero(x), msg; kwargs...) +check_equal(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true # remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 -check_equal(x::NoTangent, y::Nothing; kwargs...) = @test true -check_equal(x::Nothing, y::NoTangent; kwargs...) = @test true +check_equal(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true +check_equal(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true # Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally # not yet been implemented # `@test_broken x == y` yields more descriptive messages than `@test_broken false` -check_equal(x::ChainRulesCore.NotImplemented, y; kwargs...) = @test_broken x == y -check_equal(x, y::ChainRulesCore.NotImplemented; kwargs...) = @test_broken x == y +check_equal(x::ChainRulesCore.NotImplemented, y, msg=""; kwargs...) = @test_broken x == y +check_equal(x, y::ChainRulesCore.NotImplemented, msg=""; kwargs...) = @test_broken x == y # In this case we check for equality (messages etc. have to be equal) function check_equal( - x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented; kwargs... + x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented, msg=""; kwargs... ) - return @test x == y + return @test_msg msg x == y end """ _can_pass_early(actual, expected; kwargs...) -Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper; +Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper and can just report `check_equal` as passing. If either `==` or `≈` return true then so does this. @@ -64,30 +70,34 @@ function _can_pass_early(actual, expected; kwargs...) return false end -function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...) +function check_equal(actual::AbstractArray, expected::AbstractArray, msg=""; kwargs...) if _can_pass_early(actual, expected) @test true else - @test eachindex(actual) == eachindex(expected) - @testset "$(typeof(actual))[$ii]" for ii in eachindex(actual) - check_equal(actual[ii], expected[ii]; kwargs...) + @test_msg "$msg: indexes must match" eachindex(actual) == eachindex(expected) + for ii in eachindex(actual) + new_msg = "$msg $(typeof(actual))[$ii]" + check_equal(actual[ii], expected[ii], new_msg; kwargs...) end end end -function check_equal(actual::Tangent{P}, expected::Tangent{P}; kwargs...) where {P} +function check_equal(actual::Tangent{P}, expected::Tangent{P}, msg=""; kwargs...) where {P} if _can_pass_early(actual, expected) @test true else all_keys = union(keys(actual), keys(expected)) - @testset "$P.$ii" for ii in all_keys - check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...) + for ii in all_keys + new_msg = "$msg $P.$ii" + check_equal( + getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs... + ) end end end function check_equal( - ::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}; kwargs... + ::Tangent{ActualPrimal}, expected::Tangent{ExpectedPrimal}, msg=""; kwargs... ) where {ActualPrimal,ExpectedPrimal} # this will certainly fail as we have another dispatch for that, but this will give as # good error message @@ -95,7 +105,7 @@ function check_equal( end # Some structual differential and a natural differential -function check_equal(actual::Tangent{P,T}, expected; kwargs...) where {T,P} +function check_equal(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T,P} if _can_pass_early(actual, expected) @test true else @@ -103,21 +113,28 @@ function check_equal(actual::Tangent{P,T}, expected; kwargs...) where {T,P} # We are only checking the properties that are in the Tangent # the natural differential is allowed to have other properties that we ignore - @testset "$P.$ii" for ii in propertynames(actual) - check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...) + for ii in propertynames(actual) + new_msg = "$msg $P.$ii" + check_equal( + getproperty(actual, ii), getproperty(expected, ii), new_msg; kwargs... + ) end end end -check_equal(x, y::Tangent; kwargs...) = check_equal(y, x; kwargs...) +check_equal(x, y::Tangent, msg=""; kwargs...) = check_equal(y, x, msg; kwargs...) # This catches comparisons of Tangents and Tuples/NamedTuple -# and gives an error message complaining about that +# and gives an error message complaining about that. the `@test` will definately fail const LegacyZygoteCompTypes = Union{Tuple,NamedTuple} -check_equal(::C, ::T; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test C === T -check_equal(::T, ::C; kwargs...) where {C<:Tangent,T<:LegacyZygoteCompTypes} = @test T === C +function check_equal(x::Tangent, y::LegacyZygoteCompTypes, msg=""; kwargs...) + @test_msg "$msg: for structural differentials use `Tangent`" typeof(x) === typeof(y) +end +function check_equal(x::LegacyZygoteCompTypes, y::Tangent, msg=""; kwargs...) + return check_equal(y, x, msg; kwargs...) +end # Generic fallback, probably a tuple or something -function check_equal(actual::A, expected::E; kwargs...) where {A,E} +function check_equal(actual::A, expected::E, msg=""; kwargs...) where {A,E} if _can_pass_early(actual, expected) @test true else @@ -130,6 +147,8 @@ function check_equal(actual::A, expected::E; kwargs...) where {A,E} end end +########################################################################################### + """ _check_add!!_behaviour(acc, val) @@ -146,11 +165,11 @@ function _check_add!!_behaviour(acc, val; kwargs...) # e.g. if it is immutable. We do test the `add!!` return value. # That is what people should rely on. The mutation is just to save allocations. acc_mutated = deepcopy(acc) # prevent this test changing others - return check_equal(add!!(acc_mutated, val), acc + val; kwargs...) + return check_equal(add!!(acc_mutated, val), acc + val, "in add!!"; kwargs...) end -# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally -# not yet been implemented +# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has +# intentionally not yet been implemented # `@test_broken x == y` yields more descriptive messages than `@test_broken false` function _check_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...) return @test_broken acc_mutated == acc @@ -158,7 +177,7 @@ end function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...) return @test_broken acc_mutated == acc end -# In this case we check for equality (messages etc. have to be equal) +# In this case we check for equality (not implemented messages etc. have to be equal) function _check_add!!_behaviour( acc_mutated::ChainRulesCore.NotImplemented, acc::ChainRulesCore.NotImplemented; diff --git a/src/output_control.jl b/src/output_control.jl new file mode 100644 index 00000000..400bde7e --- /dev/null +++ b/src/output_control.jl @@ -0,0 +1,59 @@ +# Test.get_test_result generates code that uses the following so we must import them +using Test: Returned, Threw, eval_test + +"A cunning hack to carry extra message along with the original expression in a test" +struct ExprAndMsg + ex + msg +end + +""" + @test_msg msg condion kws... + +This is per `Test.@test condion kws...` except that if it fails it also prints the `msg`. +If `msg==""` then this is just like `@test`, nothing is printed + +### Examles +```julia +julia> @test_msg "It is required that the total is under 10" sum(1:1000) < 10; +Test Failed at REPL[1]:1 + Expression: sum(1:1000) < 10 + Problem: It is required that the total is under 10 + Evaluated: 500500 < 10 +ERROR: There was an error during testing + + +julia> @test_msg "It is required that the total is under 10" error("not working at all"); +Error During Test at REPL[2]:1 + Test threw exception + Expression: error("not working at all") + Problem: It is required that the total is under 10 + "not working at all" + Stacktrace: + +julia> a = ""; + +julia> @test_msg a sum(1:1000) < 10; + Test Failed at REPL[153]:1 + Expression: sum(1:1000) < 10 + Evaluated: 500500 < 10 + ERROR: There was an error during testing +``` +""" +macro test_msg(msg, ex, kws...) + Test.test_expr!("@test_msg msg", ex, kws...) + + result = Test.get_test_result(ex, __source__) + return :(Test.do_test($result, $ExprAndMsg($(string(ex)), $(esc(msg))))) +end + +function Base.print(io::IO, x::ExprAndMsg) + print(io, x.ex) + !isempty(x.msg) && print(io, "\n Problem: ", x.msg) +end + + +### helpers for printing in log messages etc +_string_typeof(x) = string(typeof(x)) +_string_typeof(xs::Tuple) = join(_string_typeof.(xs), ",") +_string_typeof(x::PrimalAndTangent) = _string_typeof(primal(x)) # only show primal diff --git a/src/testers.jl b/src/testers.jl index 225c734e..a81cf203 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -99,7 +99,7 @@ function test_frule( # To simplify some of the calls we make later lets group the kwargs for reuse isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) - @testset "test_frule: $f on $(join(typeof.(inputs), ","))" begin + @testset "test_frule: $f on $(_string_typeof(inputs))" begin _ensure_not_running_on_functor(f, "test_frule") xẋs = auto_primal_and_tangent.(inputs) @@ -167,7 +167,7 @@ function test_rrule( # To simplify some of the calls we make later lets group the kwargs for reuse isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) - @testset "test_rrule: $f on $(join(typeof.(inputs), ","))" begin + @testset "test_rrule: $f on $(_string_typeof(inputs))" begin _ensure_not_running_on_functor(f, "test_rrule") # Check correctness of evaluation. @@ -226,12 +226,11 @@ function test_rrule( end function check_thunking_is_appropriate(x̄s) - @testset "Don't thunk only non_zero argument" begin - num_zeros = count(x -> x isa AbstractZero, x̄s) - num_thunks = count(x -> x isa Thunk, x̄s) - if num_zeros + num_thunks == length(x̄s) - @test num_thunks !== 1 - end + num_zeros = count(x -> x isa AbstractZero, x̄s) + num_thunks = count(x -> x isa Thunk, x̄s) + if num_zeros + num_thunks == length(x̄s) + # num_thunks can be either 0, or greater than 1. + @test_msg "Should not thunk only non_zero argument" num_thunks != 1 end end diff --git a/test/testers.jl b/test/testers.jl index 611f1881..ac9a7ad1 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -1,4 +1,5 @@ -# For some reason if these aren't defined here, then they are interpreted as closures +# Defining test functions here as if they are defined where used it is too easy to +# mistakenly create closures over variables that only share names by coincidence. futestkws(x; err=true) = err ? error("futestkws_err") : x fbtestkws(x, y; err=true) = err ? error("fbtestkws_err") : x @@ -268,7 +269,6 @@ end return first(x), first_pullback end - #CTuple{N} = Tangent{NTuple{N, Float64}} # shorter for testing @testset "test_frule" begin test_frule(first, (2.0, 3.0)) test_frule(first, Tuple(randn(4))) From 525ffa4340b594ef13b25bdf0e0c94442d0b219a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 1 Jun 2021 18:19:26 +0100 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Miha Zgubic --- src/check_result.jl | 7 +++---- src/output_control.jl | 3 +++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/check_result.jl b/src/check_result.jl index 2a9a54f5..c4bb12b8 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -18,8 +18,7 @@ All keyword arguments are passed to `isapprox`. function check_equal( actual::Union{AbstractArray{<:Number},Number}, expected::Union{AbstractArray{<:Number},Number}, - msg="", - ; + msg=""; kwargs..., ) @test_msg msg isapprox(actual, expected; kwargs...) @@ -74,7 +73,7 @@ function check_equal(actual::AbstractArray, expected::AbstractArray, msg=""; kwa if _can_pass_early(actual, expected) @test true else - @test_msg "$msg: indexes must match" eachindex(actual) == eachindex(expected) + @test_msg "$msg: indices must match" eachindex(actual) == eachindex(expected) for ii in eachindex(actual) new_msg = "$msg $(typeof(actual))[$ii]" check_equal(actual[ii], expected[ii], new_msg; kwargs...) @@ -124,7 +123,7 @@ end check_equal(x, y::Tangent, msg=""; kwargs...) = check_equal(y, x, msg; kwargs...) # This catches comparisons of Tangents and Tuples/NamedTuple -# and gives an error message complaining about that. the `@test` will definately fail +# and gives an error message complaining about that. the `@test` will definitely fail const LegacyZygoteCompTypes = Union{Tuple,NamedTuple} function check_equal(x::Tangent, y::LegacyZygoteCompTypes, msg=""; kwargs...) @test_msg "$msg: for structural differentials use `Tangent`" typeof(x) === typeof(y) diff --git a/src/output_control.jl b/src/output_control.jl index 400bde7e..d9eba570 100644 --- a/src/output_control.jl +++ b/src/output_control.jl @@ -41,6 +41,9 @@ julia> @test_msg a sum(1:1000) < 10; ``` """ macro test_msg(msg, ex, kws...) + # This code is basically a evil hack that accesses the internals of the Test stdlib. + # Code below is based on the `@test` macro definition as it was in Julia 1.6. + # https://github.com/JuliaLang/julia/blob/v1.6.1/stdlib/Test/src/Test.jl#L371-L376 Test.test_expr!("@test_msg msg", ex, kws...) result = Test.get_test_result(ex, __source__)