-
Notifications
You must be signed in to change notification settings
Fork 22
Closed
JuliaDiff/DiffRules.jl
#82

Description
using LogExpFunctions
using ForwardDiff
using Zygote
using Random
using Test
"""
    log1mexpm(x)

Return `log1mexp(-x)`, i.e. `log1mexp` evaluated at the negated argument.
"""
function log1mexpm(x)
    return log1mexp(-x)
end

"""
    log2mexpm(x)

Return `log2mexp(-x)`, i.e. `log2mexp` evaluated at the negated argument.
"""
function log2mexpm(x)
    return log2mexp(-x)
end
"""
    test_ad_vs_zygote(AD; rng = MersenneTwister(0), ntrials = 100)

Compare derivatives of a selection of LogExpFunctions functions computed by the
automatic-differentiation module `AD` against those computed by Zygote.

`AD` must provide `derivative(f, x)`, `gradient(f, x)` and `jacobian(f, x)`, as the
`ForwardDiff` module does.

# Keywords
- `rng::AbstractRNG`: source of the random evaluation points. The seeded default
  makes every call evaluate at the same points, so failures are reproducible.
- `ntrials::Integer`: number of random evaluation points per function.

Results are reported through the enclosing `@testset`s; the function returns
`nothing`.
"""
function test_ad_vs_zygote(
    AD;
    rng::AbstractRNG = MersenneTwister(0),
    ntrials::Integer = 100,
)
    # Every evaluation point is drawn with `rand`, i.e. from [0, 1). The reflected
    # helpers `log1mexpm`/`log2mexpm` are used instead of `log1mexp`/`log2mexp`,
    # presumably to keep such points inside those functions' domains.
    #
    # `invsoftplus` may be the very same function object as `logexpm1`; `unique`
    # (identity comparison for functions) prevents testing it twice under a
    # duplicate testset name.
    single_argument = unique((
        log1pmx, logexpm1, logsumexp, xexpx, log1psq, logistic, invsoftplus,
        log2mexpm, logit, log1mexpm, logmxp1, xlogx, log1pexp, logcosh,
    ))
    two_argument = (xexpy, xlogy, xlog1py, logaddexp, logsubexp)
    vector_argument = (softmax,)

    @testset verbose = true "LogExpFunctions" begin
        @testset "Single-argument functions" begin
            @testset "$f" for f in single_argument
                for _ in 1:ntrials
                    x = rand(rng)
                    @test AD.derivative(f, x) ≈ only(Zygote.gradient(f, x))
                end
            end
        end
        @testset "Two-argument functions" begin
            @testset "$f" for f in two_argument
                # Both AD backends differentiate w.r.t. a vector, so splat it into f.
                splatted = x -> f(x...)
                for _ in 1:ntrials
                    x = rand(rng, 2)
                    @test AD.gradient(splatted, x) ≈ only(Zygote.gradient(splatted, x))
                end
            end
        end
        @testset "Vector-argument functions" begin
            @testset "$f" for f in vector_argument
                for _ in 1:ntrials
                    x = rand(rng, 3)
                    @test AD.jacobian(f, x) ≈ only(Zygote.jacobian(f, x))
                end
            end
        end
    end
    return nothing
end
test_ad_vs_zygote(ForwardDiff)
# yields:
Test Summary: | Pass Error Total
LogExpFunctions | 1800 200 2000
Single-argument functions | 1200 200 1400
log1pmx | 100 100
logexpm1 | 100 100
logsumexp | 100 100
xexpx | 100 100
log1psq | 100 100
logistic | 100 100
logexpm1 | 100 100
log2mexpm | 100 100
logit | 100 100
log1mexpm | 100 100
logmxp1 | 100 100
xlogx | 100 100
log1pexp | 100 100
logcosh | 100 100
Two-argument functions | 500 500
Vector-argument functions | 100 100
ERROR: Some tests did not pass: 1800 passed, 0 failed, 200 errored, 0 broken.

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels