Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
947382a
initial empty commit
willtebbutt Jan 17, 2018
93b6058
Implement rand_tangent and difference (#91)
willtebbutt Jul 11, 2020
786983c
make rand_tangent of struct with no pertable fields return DoesNotExist
oxinabox Apr 29, 2021
22e1fbd
copy `rand_tangent(::BigFloat)` from `ChainRulesTestUtils` (#155)
mzgubic May 4, 2021
094a78b
Merge pull request #158 from JuliaDiff/ox/simrandtan
oxinabox May 7, 2021
ebff52d
rename differentials (#162)
mzgubic May 26, 2021
89133ec
=Make rand_tangent on adjoint an transpose return natural
oxinabox May 27, 2021
f4a97a7
typo
oxinabox May 27, 2021
6a09c8a
braces
oxinabox May 27, 2021
f71f77b
stop testing + behavour that is not defines in this package
oxinabox May 28, 2021
8dc43d1
Merge pull request #165 from JuliaDiff/ox/covectors
oxinabox May 28, 2021
e776799
Replace `NO_FIELDS` by `NoTangent()` (#163)
mzgubic Jun 1, 2021
460fc3d
make rand_tangent give only nice numbers
oxinabox May 28, 2021
a7a416c
fix rand_tangent(::BigFloat)
oxinabox May 28, 2021
00596e3
Explain why scaling randn
oxinabox Jun 1, 2021
12033c6
only test printing length on 1.6 (also relax length for bigfloats)
oxinabox Jun 1, 2021
231e685
Relax big further
oxinabox Jun 1, 2021
76619c1
Merge pull request #168 from JuliaDiff/ox/nicerand
oxinabox Jun 1, 2021
bc71ac1
Add `rand_tangent` for types (#172)
mzgubic Jun 8, 2021
61a83ed
Revert "Add `rand_tangent` for types (#172)"
Jun 9, 2021
96b1abd
Revert "Revert "Add `rand_tangent` for types (#172)""
Jun 9, 2021
b8d5573
Add rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T}
Jul 1, 2021
89b202c
Merge pull request #180 from JuliaDiff/ar/rand_tangent_zeroarrays
AlexRobson Jul 1, 2021
5761090
Merge branch 'master' of ../ChainRulesTestUtils.jl
oxinabox Jul 21, 2021
b4bedeb
include rand_tangent
oxinabox Jul 21, 2021
30f4c03
Don't import rand_tangent from FiniteDifferences
oxinabox Jul 21, 2021
fc3fb1e
Fix up tests that don't work because of move
oxinabox Jul 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/ChainRulesTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@ using LinearAlgebra
using Random
using Test

import FiniteDifferences: rand_tangent

export TestIterator
export test_approx, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
export ⊢
export ⊢, rand_tangent
export @maybe_inferred

__init__() = init_test_inferred_setting!()

include("global_config.jl")

include("rand_tangent.jl")
include("generate_tangent.jl")
include("data_generation.jl")
include("iterator.jl")
Expand Down
63 changes: 63 additions & 0 deletions src/rand_tangent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
rand_tangent([rng::AbstractRNG,] x)

Returns a arbitary tangent vector _appropriate_ for the primal value `x`.
Note that despite the name, no promises on the statistical randomness are made.
Rather it is an arbitary value, that is generated using the `rng`.
"""
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)

rand_tangent(rng::AbstractRNG, x::Symbol) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractChar) = NoTangent()
rand_tangent(rng::AbstractRNG, x::AbstractString) = NoTangent()

rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()

# Try and make nice numbers with short decimal representations for good error messages
# while also not biasing the sample space too much
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Number}
# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
return round(9 * randn(rng, T), sigdigits=5, base=2)
end
rand_tangent(rng::AbstractRNG, x::Float64) = rand(rng, -9:0.01:9)
function rand_tangent(rng::AbstractRNG, x::ComplexF64)
return ComplexF64(rand(rng, -9:0.1:9), rand(rng, -9:0.1:9))
end

#BigFloat/MPFR is finicky about short numbers, this doesn't always work as well as it should

# multiply by 9 to give a bigger range of values tested: no so tightly clustered around 0.
rand_tangent(rng::AbstractRNG, ::BigFloat) = round(big(9 * randn(rng)), sigdigits=5, base=2)

rand_tangent(rng::AbstractRNG, x::StridedArray{T, 0}) where {T} = fill(rand_tangent(x[1]))
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x)))
rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x)))

function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
return Tangent{T}(rand_tangent.(Ref(rng), x)...)
end

function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
return Tangent{T}(; map(x -> rand_tangent(rng, x), xs)...)
end

function rand_tangent(rng::AbstractRNG, x::T) where {T}
if !isstructtype(T)
throw(ArgumentError("Non-struct types are not supported by this fallback."))
end

field_names = fieldnames(T)
tangents = map(field_names) do field_name
rand_tangent(rng, getfield(x, field_name))
end
if all(tangent isa NoTangent for tangent in tangents)
# if none of my fields can be perturbed then I can't be perturbed
return NoTangent()
else
Tangent{T}(; NamedTuple{field_names}(tangents)...)
end
end

rand_tangent(rng::AbstractRNG, ::Type) = NoTangent()
rand_tangent(rng::AbstractRNG, ::Module) = NoTangent()
2 changes: 1 addition & 1 deletion test/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
@testset "rand_tangent" begin
data = randn(2, 3, 4)
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
∂iter = FiniteDifferences.rand_tangent(iter)
∂iter = rand_tangent(iter)
@test ∂iter isa typeof(iter)
@test size(∂iter.data) == size(iter.data)
@test eltype(∂iter.data) === eltype(iter.data)
Expand Down
106 changes: 106 additions & 0 deletions test/rand_tangent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Test struct for `rand_tangent` and `difference`.
struct Bar
a::Float64
b::Int
c::Any
end
@testset "rand_tangent" begin
rng = MersenneTwister(123456)

@testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [

# Things without sensible tangents.
("hi", NoTangent),
('a', NoTangent),
(:a, NoTangent),
(true, NoTangent),
(4, NoTangent),
(FiniteDifferences, NoTangent), # Module object
# Types (not instances of type)
(Bar, NoTangent),
(Union{Int, Bar}, NoTangent),
(Union{Int, Bar}, NoTangent),
(Vector, NoTangent),
(Vector{Float64}, NoTangent),
(Integer, NoTangent),
(Type{<:Real}, NoTangent),

# Numbers.
(5.0, Float64),
(5.0 + 0.4im, Complex{Float64}),
(big(5.0), BigFloat),

# StridedArrays.
(fill(randn(Float32)), Array{Float32, 0}),
(fill(randn(Float64)), Array{Float64, 0}),
(randn(Float32, 3), Vector{Float32}),
(randn(Complex{Float64}, 2), Vector{Complex{Float64}}),
(randn(5, 4), Matrix{Float64}),
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
([randn(5, 4), 4.0], Vector{Any}),

# Wrapper Arrays
(randn(5, 4)', Adjoint{Float64, Matrix{Float64}}),
(transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}),


# Tuples.
((4.0, ), Tangent{Tuple{Float64}}),
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),

# NamedTuples.
((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}),
((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),

# structs.
(Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}),
(Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}),
(sin, NoTangent),
# all fields NoTangent implies NoTangent
(Pair(:a, "b"), NoTangent),
(1:10, NoTangent),
(1:2:10, NoTangent),

# LinearAlgebra types (also just structs).
(
UpperTriangular(randn(3, 3)),
Tangent{UpperTriangular{Float64, Matrix{Float64}}},
),
(
Diagonal(randn(2)),
Tangent{Diagonal{Float64, Vector{Float64}}},
),
(
Symmetric(randn(2, 2)),
Tangent{Symmetric{Float64, Matrix{Float64}}},
),
(
Hermitian(randn(ComplexF64, 1, 1)),
Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}},
),
]
@test rand_tangent(rng, x) isa T_tangent
@test rand_tangent(x) isa T_tangent
end

@testset "erroring cases" begin
# Ensure struct fallback errors for non-struct types.
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
end

@testset "compsition of addition" begin
x = Bar(1.5, 2, Bar(1.1, 3, [1.7, 1.4, 0.9]))
@test x + rand_tangent(x) isa typeof(x)
@test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x)
end

# Julia 1.6 changed to using Ryu printing algorithm and seems better at printing short
VERSION > v"1.6" && @testset "niceness of printing" begin
for i in 1:50
@test length(string(rand_tangent(1.0))) <= 6
@test length(string(rand_tangent(1.0 + 1.0im))) <= 12
@test length(string(rand_tangent(1f0))) <= 12
@test length(string(rand_tangent(big"1.0"))) <= 12
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ChainRulesTestUtils.TEST_INFERRED[] = true
include("check_result.jl")
include("testers.jl")
include("data_generation.jl")
include("rand_tangent.jl")

include("deprecated.jl")
end