diff --git a/Project.toml b/Project.toml index c4b03a24..f1febd46 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogExpFunctions" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" authors = ["StatsFun.jl contributors, Tamas K. Papp "] -version = "0.3.12" +version = "0.3.13" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/basicfuns.jl b/src/basicfuns.jl index e30809c2..5b989c37 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -246,6 +246,7 @@ $(SIGNATURES) Return `log(1 + x) - x`. Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`. +This will fall back to the naive calculation for argument types different from `Float64`. """ function log1pmx(x::Float64) if !(-0.7 < x < 0.9) @@ -267,10 +268,14 @@ function log1pmx(x::Float64) end end +# Naive fallback +log1pmx(x::Real) = log1p(x) - x + """ $(SIGNATURES) Return `log(x) - x + 1` carefully evaluated. +This will fall back to the naive calculation for argument types different from `Float64`. """ function logmxp1(x::Float64) if x <= 0.3 @@ -286,6 +291,17 @@ function logmxp1(x::Float64) end end +# Naive fallback +function logmxp1(x::Real) + one_x = one(x) + if 2 * x < one_x + # for small values of `x` the other branch returns non-finite values + return (log(x) + one_x) - x + else + return log1pmx(x - one_x) + end +end + # The kernel of log1pmx # Accuracy within ~2ulps for -0.227 < x < 0.315 function _log1pmx_ker(x::Float64) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 4479080d..cccaf876 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -177,12 +177,28 @@ end @test iszero(log1pmx(0.0)) @test log1pmx(1.0) ≈ log(2.0) - 1.0 @test log1pmx(2.0) ≈ log(3.0) - 2.0 + + @test iszero(log1pmx(0f0)) + @test log1pmx(1f0) ≈ log(2f0) - 1f0 + @test log1pmx(2f0) ≈ log(3f0) - 2f0 + + for x in -0.5:0.1:10 + @test log1pmx(Float32(x)) ≈ Float32(log1pmx(x)) + end end @testset "logmxp1" begin @test iszero(logmxp1(1.0)) @test logmxp1(2.0) ≈ log(2.0) - 1.0 @test logmxp1(3.0) ≈ log(3.0) - 2.0 + + @test iszero(logmxp1(1f0)) + @test logmxp1(2f0) ≈ log(2f0) - 1f0 + @test logmxp1(3f0) ≈ log(3f0) - 2f0 + + for x in 0.1:0.1:10 + @test logmxp1(Float32(x)) ≈ Float32(logmxp1(x)) + end end @testset "logsumexp" begin