diff --git a/src/ChainRulesTestUtils.jl b/src/ChainRulesTestUtils.jl index 1dd306c..03b794e 100644 --- a/src/ChainRulesTestUtils.jl +++ b/src/ChainRulesTestUtils.jl @@ -9,17 +9,16 @@ using LinearAlgebra using Random using Test -import FiniteDifferences: rand_tangent - export TestIterator export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix -export ⊢ +export ⊢, rand_tangent export @maybe_inferred __init__() = init_test_inferred_setting!() include("global_config.jl") +include("rand_tangent.jl") include("generate_tangent.jl") include("data_generation.jl") include("iterator.jl") diff --git a/src/rand_tangent.jl b/src/rand_tangent.jl new file mode 100644 index 0000000..adb0ac8 --- /dev/null +++ b/src/rand_tangent.jl @@ -0,0 +1,63 @@ +""" + rand_tangent([rng::AbstractRNG,] x) + +Returns a arbitary tangent vector _appropriate_ for the primal value `x`. +Note that despite the name, no promises on the statistical randomness are made. +Rather it is an arbitary value, that is generated using the `rng`. +""" +rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x) + +rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent() +rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent() +rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent() + +rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent() + +# Try and make nice numbers with short decimal representations for good error messages +# while also not biasing the sample space too much +function rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} + # multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0. + return round(9 * randn(rng, T), sigdigits=5, base=2) +end +rand_tangent(rng::AbstractRNG, x::Float64) = rand(rng, -9:0.01:9) +function rand_tangent(rng::AbstractRNG, x::ComplexF64) + return ComplexF64(rand(rng, -9:0.1:9), rand(rng, -9:0.1:9)) +end + +#BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should + +# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0. +rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), sigdigits=5, base=2) + +rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T} = fill(rand_tangent(x[1])) +rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x) +rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x))) +rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x))) + +function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple} + return Tangent{T}(rand_tangent.(Ref(rng), x)...) +end + +function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple} + return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...) +end + +function rand_tangent(rng::AbstractRNG, x::T) where {T} + if !isstructtype(T) + throw(ArgumentError("Non-struct types are not supported by this fallback.")) + end + + field_names = fieldnames(T) + tangents = map(field_names) do field_name + rand_tangent(rng, getfield(x, field_name)) + end + if all(tangent isa NoTangent for tangent in tangents) + # if none of my fields can be perturbed then I can't be perturbed + return NoTangent() + else + Tangent{T}(; NamedTuple{field_names}(tangents)...) + end +end + +rand_tangent(rng::AbstractRNG, ::Type) = NoTangent() +rand_tangent(rng::AbstractRNG, ::Module) = NoTangent() diff --git a/test/iterator.jl b/test/iterator.jl index bbd9f19..df50a9b 100644 --- a/test/iterator.jl +++ b/test/iterator.jl @@ -88,7 +88,7 @@ @testset "rand_tangent" begin data = randn(2, 3, 4) iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown()) - ∂iter = FiniteDifferences.rand_tangent(iter) + ∂iter = rand_tangent(iter) @test ∂iter isa typeof(iter) @test size(∂iter.data) == size(iter.data) @test eltype(∂iter.data) === eltype(iter.data) diff --git a/test/rand_tangent.jl b/test/rand_tangent.jl new file mode 100644 index 0000000..ed479a5 --- /dev/null +++ b/test/rand_tangent.jl @@ -0,0 +1,106 @@ +# Test struct for `rand_tangent` and `difference`. +struct Bar + a::Float64 + b::Int + c::Any + end +@testset "rand_tangent" begin + rng = MersenneTwister(123456) + + @testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [ + + # Things without sensible tangents. + ("hi", NoTangent), + ('a', NoTangent), + (:a, NoTangent), + (true, NoTangent), + (4, NoTangent), + (FiniteDifferences, NoTangent), # Module object + # Types (not instances of type) + (Bar, NoTangent), + (Union{Int, Bar}, NoTangent), + (Union{Int, Bar}, NoTangent), + (Vector, NoTangent), + (Vector{Float64}, NoTangent), + (Integer, NoTangent), + (Type{<:Real}, NoTangent), + + # Numbers. + (5.0, Float64), + (5.0 + 0.4im, Complex{Float64}), + (big(5.0), BigFloat), + + # StridedArrays. + (fill(randn(Float32)), Array{Float32, 0}), + (fill(randn(Float64)), Array{Float64, 0}), + (randn(Float32, 3), Vector{Float32}), + (randn(Complex{Float64}, 2), Vector{Complex{Float64}}), + (randn(5, 4), Matrix{Float64}), + (randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}), + ([randn(5, 4), 4.0], Vector{Any}), + + # Wrapper Arrays + (randn(5, 4)', Adjoint{Float64, Matrix{Float64}}), + (transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}), + + + # Tuples. + ((4.0, ), Tangent{Tuple{Float64}}), + ((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}), + + # NamedTuples. + ((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}), + ((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}), + + # structs. + (Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}), + (Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}), + (sin, NoTangent), + # all fields NoTangent implies NoTangent + (Pair(:a, "b"), NoTangent), + (1:10, NoTangent), + (1:2:10, NoTangent), + + # LinearAlgebra types (also just structs). + ( + UpperTriangular(randn(3, 3)), + Tangent{UpperTriangular{Float64, Matrix{Float64}}}, + ), + ( + Diagonal(randn(2)), + Tangent{Diagonal{Float64, Vector{Float64}}}, + ), + ( + Symmetric(randn(2, 2)), + Tangent{Symmetric{Float64, Matrix{Float64}}}, + ), + ( + Hermitian(randn(ComplexF64, 1, 1)), + Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}}, + ), + ] + @test rand_tangent(rng, x) isa T_tangent + @test rand_tangent(x) isa T_tangent + end + + @testset "erroring cases" begin + # Ensure struct fallback errors for non-struct types. + @test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0) + end + + @testset "compsition of addition" begin + x = Bar(1.5, 2, Bar(1.1, 3, [1.7, 1.4, 0.9])) + @test x + rand_tangent(x) isa typeof(x) + @test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x) + end + + # Julia 1.6 changed to using Ryu printing algorithm and seems better at printing short + VERSION > v"1.6" && @testset "niceness of printing" begin + for i in 1:50 + @test length(string(rand_tangent(1.0))) <= 6 + @test length(string(rand_tangent(1.0 + 1.0im))) <= 12 + @test length(string(rand_tangent(1f0))) <= 12 + @test length(string(rand_tangent(big"1.0"))) <= 12 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 845d696..10f73c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ ChainRulesTestUtils.TEST_INFERRED[] = true include("check_result.jl") include("testers.jl") include("data_generation.jl") + include("rand_tangent.jl") include("deprecated.jl") end