Skip to content

log1pmx and logmxp1 cannot be differentiated by ForwardDiff #44

@simsurace

Description

@simsurace
using LogExpFunctions

using ForwardDiff
using Zygote
using Test

"""
    log1mexpm(x)

Return `log1mexp(-x)`: `log1mexp` evaluated at the negated argument. The AD comparison
tests use this wrapper so that the non-negative samples drawn with `rand()` can be fed
in directly.
"""
function log1mexpm(x)
    return log1mexp(-x)
end

"""
    log2mexpm(x)

Return `log2mexp(-x)`: `log2mexp` evaluated at the negated argument. The AD comparison
tests use this wrapper so that the non-negative samples drawn with `rand()` can be fed
in directly.
"""
function log2mexpm(x)
    return log2mexp(-x)
end

"""
    test_ad_vs_zygote(AD; ntrials::Integer = 100)

Compare derivatives from the automatic-differentiation module `AD` against Zygote.jl
for a selection of LogExpFunctions.jl functions. Each function is checked at `ntrials`
random points drawn with `rand` (uniform on `[0, 1)`).

`AD` must provide `derivative(f, x)`, `gradient(f, x)` and `jacobian(f, x)` in the
style of ForwardDiff.jl.

Every comparison is an `@test … ≈ …` inside nested `@testset`s, so a failing or
throwing comparison is recorded rather than aborting the run. The outermost testset
prints a summary and throws a `TestSetException` if anything failed or errored.

Returns `nothing`.
"""
function test_ad_vs_zygote(AD; ntrials::Integer = 100)
    @testset verbose = true "LogExpFunctions" begin
        @testset "Single-argument functions" begin
            # `log1mexpm`/`log2mexpm` negate their argument, so the non-negative
            # samples from `rand()` stay inside the domains of `log1mexp`/`log2mexp`.
            @testset "$f" for f in (
                log1pmx,
                logexpm1,
                logsumexp,
                xexpx,
                log1psq,
                logistic,
                invsoftplus,
                log2mexpm,
                logit,
                log1mexpm,
                logmxp1,
                xlogx,
                log1pexp,
                logcosh,
            )
                for _ in 1:ntrials
                    par = rand()
                    # Zygote differentiates `f ∘ only` at a 1-element vector; the outer
                    # `only` unwraps the gradient tuple, the inner one the vector.
                    @test AD.derivative(f, par) ≈
                          only(only(Zygote.gradient(f ∘ only, [par])))
                end
            end
        end
        @testset "Two-argument functions" begin
            @testset "$f" for f in (
                xexpy,
                xlogy,
                xlog1py,
                logaddexp,
                logsubexp,
            )
                # Both backends differentiate the same vector -> scalar closure.
                splatted = x -> f(x...)
                for _ in 1:ntrials
                    par = rand(2)
                    @test AD.gradient(splatted, par) ≈
                          only(Zygote.gradient(splatted, par))
                end
            end
        end
        @testset "Vector-argument functions" begin
            @testset "$f" for f in (
                softmax,
            )
                for _ in 1:ntrials
                    par = rand(3)
                    @test AD.jacobian(f, par) ≈ only(Zygote.jacobian(f, par))
                end
            end
        end
    end
    return nothing
end

# Run the LogExpFunctions comparison with ForwardDiff as the backend under test.
ForwardDiff |> test_ad_vs_zygote

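Running this script yields the following test summary. The second `logexpm1` row corresponds to `invsoftplus`, an alias of `logexpm1` that is displayed under that name: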

Test Summary:               | Pass  Error  Total
LogExpFunctions             | 1800    200   2000
  Single-argument functions | 1200    200   1400
    log1pmx                 |         100    100
    logexpm1                |  100           100
    logsumexp               |  100           100
    xexpx                   |  100           100
    log1psq                 |  100           100
    logistic                |  100           100
    logexpm1                |  100           100
    log2mexpm               |  100           100
    logit                   |  100           100
    log1mexpm               |  100           100
    logmxp1                 |         100    100
    xlogx                   |  100           100
    log1pexp                |  100           100
    logcosh                 |  100           100
  Two-argument functions    |  500           500
  Vector-argument functions |  100           100
ERROR: Some tests did not pass: 1800 passed, 0 failed, 200 errored, 0 broken.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions