Skip to content
25 changes: 15 additions & 10 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# This file is machine-generated - editing it directly is not advised

[[ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9"
version = "0.0.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

Expand All @@ -11,15 +16,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "0b0aa9d61456940511416b59a0e902c57b154956"
git-tree-sha1 = "f53ca8d41e4753c41cdafa6ec5f7ce914b34be54"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.12"
version = "0.10.13"

[[ChainRulesTestUtils]]
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
path = ".."
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.13"
version = "1.0.0-DEV"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand All @@ -46,20 +51,20 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.5"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "47f13b6305ab195edb73c86815962d84e31b0f48"
deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "95265abf7d7bf06dfdb8d58525a23ea5fb0bdeee"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.27.3"
version = "0.27.4"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FiniteDifferences]]
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
git-tree-sha1 = "12417e4754486a547d98d65293dc0fafdfcc0736"
git-tree-sha1 = "18761c465ef2e87d9091c0fefb61f70d532d4cc0"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.14"
version = "0.12.16"

[[IOCapture]]
deps = ["Logging", "Random"]
Expand Down Expand Up @@ -167,9 +172,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "a43a7b58a6e7dc933b2fa2e0ca653ccf8bb8fd0e"
git-tree-sha1 = "1b9a0f17ee0adde9e538227de093467348992397"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.2.6"
version = "1.2.7"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down
6 changes: 0 additions & 6 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,3 @@
Modules = [ChainRulesTestUtils]
Private = false
```


## Global Configuration
```@docs
ChainRulesTestUtils.enable_tangent_transform!
```
6 changes: 3 additions & 3 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly
```jldoctest ex
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 8 8
test_rrule: two2three on Float64,Float64 | 9 9

```

Expand Down Expand Up @@ -100,12 +100,12 @@ call.
```jldoctest ex
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 10 10
test_scalar: relu at 0.5 | 11 11
Comment on lines 102 to +103

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Documenter (fix doctests)] reported by reviewdog 🐶

Suggested change
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 10 10
test_scalar: relu at 0.5 | 11 11



julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 10 10
test_scalar: relu at -0.5 | 11 11

```

Expand Down
2 changes: 0 additions & 2 deletions src/ChainRulesTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@ include("check_result.jl")
include("rule_config.jl")
include("finite_difference_calls.jl")
include("testers.jl")

include("deprecated.jl")
end # module
2 changes: 1 addition & 1 deletion src/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ function test_approx(actual::A, expected::E, msg=""; kwargs...) where {A,E}
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
throw(MethodError, test_approx, (actual, expected))
end
test_approx(c_actual, c_expected; kwargs...)
test_approx(c_actual, c_expected, msg; kwargs...)
end
end

Expand Down
93 changes: 0 additions & 93 deletions src/deprecated.jl

This file was deleted.

24 changes: 1 addition & 23 deletions src/global_config.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,6 @@
const _fdm = central_fdm(5, 1; max_range=1e-2)
const TEST_INFERRED = Ref(true)
const TRANSFORMS_TO_ALT_TANGENTS = Function[] # e.g. [x -> @thunk(x), _ -> ZeroTangent(), x -> rebasis(x)]

"""
enable_tangent_transform!(Thunk)

Adds a alt-tangent tranform to the list of default `tangent_transforms` for
[`test_frule`](@ref) and [`test_rrule`](@ref) to test.
This list of defaults is overwritten by the `tangent_transforms` keyword argument.

!!! info "Transitional Feature"
ChainRulesCore v1.0 will require that all well-behaved rules work for a variety of
tangent representations. In turn, the corresponding release of ChainRulesTestUtils will
test all the different tangent representations by default.
At that stage `enable_tangent_transform!(Thunk)` will have no effect, as it will already
be enabled.
We provide this configuration as a transitional feature to help migrate your packages
one feature at a time, prior to the breaking release of ChainRulesTestUtils that will
enforce it.
"""
function enable_tangent_transform!(::Type{Thunk})
push!(TRANSFORMS_TO_ALT_TANGENTS, x->@thunk(x))
unique!(TRANSFORMS_TO_ALT_TANGENTS)
end
const TRANSFORMS_TO_ALT_TANGENTS = Function[x->@thunk(x)] # e.g. [_ -> ZeroTangent(), x -> rebasis(x)]

"sets up TEST_INFERRED based ion enviroment variables"
function init_test_inferred_setting!()
Expand Down
74 changes: 9 additions & 65 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ end
# Keyword Arguments
- `output_tangent` tangent to test accumulation of derivatives against
should be a differential for the output of `f`. Is set automatically if not provided.
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
transform the passed argument tangents into alternative tangents that should be tested.
Note that the alternative tangents are only tested for not erroring when passed to
frule. Testing for correctness using finite differencing can be done using a
separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
`test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
`frule`). Used for testing gradients from AD systems.
Expand All @@ -104,7 +98,6 @@ function test_frule(
f,
args...;
output_tangent=Auto(),
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
fdm=_fdm,
frule_f=ChainRulesCore.frule,
check_inferred::Bool=true,
Expand Down Expand Up @@ -136,41 +129,16 @@ function test_frule(
Ω = call_on_copy(primals...)
test_approx(Ω_ad, Ω; isapprox_kwargs...)

# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
is_ignored = isa.(tangents, Union{Nothing,NoTangent})
if any(tangents .== nothing)
Base.depwarn(
"test_frule(f, k ⊢ nothing) is deprecated, use " *
"test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
:test_frule,
)
end

# Correctness testing via finite differencing.
is_ignored = isa.(tangents, NoTangent)
dΩ_fd = _make_jvp_call(fdm, call_on_copy, Ω, primals, tangents, is_ignored)
test_approx(dΩ_ad, dΩ_fd; isapprox_kwargs...)

acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
_test_add!!_behaviour(acc, dΩ_ad; isapprox_kwargs...)

# test that rules work for other tangents
_test_frule_alt_tangents(
call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
isapprox_kwargs...
)
end # top-level testset
end

function _test_frule_alt_tangents(
call, frule_f, config, tangent_transforms, tangents, primals, acc;
isapprox_kwargs...
)
@testset "ȧrgs = $(_string_typeof(tsf.(tangents)))" for tsf in tangent_transforms
_, dΩ = call(frule_f, config, tsf.(tangents), primals...)
_test_add!!_behaviour(acc, dΩ; isapprox_kwargs...)
end
end

"""
test_rrule([config::RuleConfig,] f, args...; kwargs...)

Expand All @@ -185,12 +153,8 @@ end
# Keyword Arguments
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
should be a differential for the output of `f`. Is set automatically if not provided.
- `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
transform the passed `output_tangent` into alternative tangents that should be tested.
Note that the alternative tangents are only tested for not erroring when passed to
rrule. Testing for correctness using finite differencing can be done using a
separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
`test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
output tangent to the pullback returns the same result.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
Used for testing gradients from AD systems.
Expand All @@ -209,7 +173,7 @@ function test_rrule(
f,
args...;
output_tangent=Auto(),
tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS,
check_thunked_output_tangent=true,
fdm=_fdm,
rrule_f=ChainRulesCore.rrule,
check_inferred::Bool=true,
Expand Down Expand Up @@ -254,22 +218,13 @@ function test_rrule(
)

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
is_ignored = isa.(accum_cotangents, Union{Nothing, NoTangent})
if any(accum_cotangents .== nothing)
Base.depwarn(
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
"test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
:test_rrule,
)
end

is_ignored = isa.(accum_cotangents, NoTangent)
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)

for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
accum_cotangents, ad_cotangents, fd_cotangents
)
if accum_cotangent isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
@assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
ad_cotangent isa ZeroTangent && error(
"The pullback in the rrule should use NoTangent()" *
Expand All @@ -285,21 +240,10 @@ function test_rrule(
end
end

# test other tangents don't error when passed to the pullback
_test_rrule_alt_tangents(pullback, tangent_transforms, ȳ, accum_cotangents)
end # top-level testset
end

function _test_rrule_alt_tangents(
pullback, tangent_transforms, ȳ, accum_cotangents;
isapprox_kwargs...
)
@testset "ȳ = $(_string_typeof(tsf(ȳ)))" for tsf in tangent_transforms
ad_cotangents = pullback(tsf(ȳ))
for (accum_cotangent, ad_cotangent) in zip(accum_cotangents, ad_cotangents)
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
if check_thunked_output_tangent
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:")
end
end
end # top-level testset
end

"""
Expand Down
Loading