diff --git a/Project.toml b/Project.toml index 9092e20..dc435c4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.9" +version = "0.12.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] -ChainRulesCore = "0.9.44" +ChainRulesCore = "0.9.44, 0.10" Richardson = "1.2" StaticArrays = "0.12, 1.0" julia = "1" diff --git a/src/rand_tangent.jl b/src/rand_tangent.jl index cdca3aa..4ca2548 100644 --- a/src/rand_tangent.jl +++ b/src/rand_tangent.jl @@ -35,17 +35,13 @@ function rand_tangent(rng::AbstractRNG, x::T) where {T} end field_names = fieldnames(T) - if length(field_names) > 0 - tangents = map(field_names) do field_name - rand_tangent(rng, getfield(x, field_name)) - end - if all(tangent isa NoTangent for tangent in tangents) - # if none of my fields can be perturbed then I can't be perturbed - return NoTangent() - else - Tangent{T}(; NamedTuple{field_names}(tangents)...) - end + tangents = map(field_names) do field_name + rand_tangent(rng, getfield(x, field_name)) + end + if all(tangent isa NoTangent for tangent in tangents) + # if none of my fields can be perturbed then I can't be perturbed + return NoTangent() else - return NO_FIELDS + Tangent{T}(; NamedTuple{field_names}(tangents)...) end end diff --git a/test/rand_tangent.jl b/test/rand_tangent.jl index 876f91c..92a2d50 100644 --- a/test/rand_tangent.jl +++ b/test/rand_tangent.jl @@ -40,7 +40,7 @@ using FiniteDifferences: rand_tangent # structs. (Foo(5.0, 4, rand(rng, 3)), Tangent{Foo}), (Foo(4.0, 3, Foo(5.0, 2, 4)), Tangent{Foo}), - (sin, typeof(NO_FIELDS)), + (sin, NoTangent), # all fields NoTangent implies NoTangent (Pair(:a, "b"), NoTangent), (1:10, NoTangent),