From 407af5668e186331733b2b79523436011dfe86d0 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 27 Apr 2022 20:05:39 +0200 Subject: [PATCH 1/7] Add diffrules for `log1pmx` and `logmxp1` --- src/rules.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rules.jl b/src/rules.jl index d1ca0a2..75f2628 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -244,6 +244,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.log1mexp(x) = :(-exp($x - LogExpFunctions.log1mexp($x))) @define_diffrule LogExpFunctions.log2mexp(x) = :(-exp($x - LogExpFunctions.log2mexp($x))) @define_diffrule LogExpFunctions.logexpm1(x) = :(exp($x - LogExpFunctions.logexpm1($x))) +@define_diffrule LogExpFunctions.log1pmx(x) = :(-$x / (1 + $x)) +@define_diffrule LogExpFunctions.logmxp1(x) = :(inv($x) - 1) # binary @define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y) From 0646180edd6b8133cabe82a5247f5828827dbef4 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 27 Apr 2022 20:59:30 +0200 Subject: [PATCH 2/7] Check for errors --- test/runtests.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6618a36..1f5eea1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,7 +40,11 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, # We're happy with types with the correct promotion behavior, e.g. # it's fine to return `1` as a derivative despite input being `Float64`. @test promote_type(typeof($deriv), $T) === $T - @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 + if $(f in (:log1pmx, :logmxp1) && T == Float32 + @test_throws MethodError $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 + else + @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 + end # test for 2pi functions if $(f === :mod2pi) goo = 4 * pi From 5354387bd81a775cb51fd80327a5312e75e13619 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 27 Apr 2022 21:06:13 +0200 Subject: [PATCH 3/7] Fix typo --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1f5eea1..12bd003 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,7 +40,7 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, # We're happy with types with the correct promotion behavior, e.g. # it's fine to return `1` as a derivative despite input being `Float64`. @test promote_type(typeof($deriv), $T) === $T - if $(f in (:log1pmx, :logmxp1) && T == Float32 + if $(f in (:log1pmx, :logmxp1)) && T == Float32 @test_throws MethodError $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 else @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 From b1795d3611541ab6ba107b84e1af2ef327dc5128 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 27 Apr 2022 21:27:37 +0200 Subject: [PATCH 4/7] Fix test --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 12bd003..251ec61 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,8 +40,8 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, # We're happy with types with the correct promotion behavior, e.g. # it's fine to return `1` as a derivative despite input being `Float64`. @test promote_type(typeof($deriv), $T) === $T - if $(f in (:log1pmx, :logmxp1)) && T == Float32 - @test_throws MethodError $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 + if $(f in (:log1pmx, :logmxp1)) && $T == Float32 + @test_throws MethodError finitediff($M.$f, goo) else @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 end From 193c358b3307fb94242c988d42016e2ead757ccb Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 28 Apr 2022 10:16:30 +0200 Subject: [PATCH 5/7] Add `logmxp1` to branch that adds 0.5 to the input --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 251ec61..c65065f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,7 +27,7 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, goo = if $(f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) # avoid singularities with finite differencing rand($T) + $T(1.5) - elseif $(f in (:log, :airyaix, :airyaiprimex)) + elseif $(f in (:log, :airyaix, :airyaiprimex, :logmxp1)) # avoid singularities with finite differencing rand($T) + $T(0.5) elseif $(f === :log1mexp) From 4370be831cfc6db6f557afd2e58909418ff33f48 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 28 Apr 2022 10:17:59 +0200 Subject: [PATCH 6/7] Add comment on why finite differencing fails here --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index c65065f..ba687a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,8 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, # it's fine to return `1` as a derivative despite input being `Float64`. @test promote_type(typeof($deriv), $T) === $T if $(f in (:log1pmx, :logmxp1)) && $T == Float32 + # These two functions currently don't have fallbacks for `Real` + # arguments, nor optimized implementations for `Float32` @test_throws MethodError finitediff($M.$f, goo) else @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 From d82f0ecb6f4ca3c64bbdefbf001ae9e159041267 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Thu, 28 Apr 2022 14:55:42 +0200 Subject: [PATCH 7/7] Update src/rules.jl Co-authored-by: David Widmann --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 75f2628..45104aa 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -245,7 +245,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.log2mexp(x) = :(-exp($x - LogExpFunctions.log2mexp($x))) @define_diffrule LogExpFunctions.logexpm1(x) = :(exp($x - LogExpFunctions.logexpm1($x))) @define_diffrule LogExpFunctions.log1pmx(x) = :(-$x / (1 + $x)) -@define_diffrule LogExpFunctions.logmxp1(x) = :(inv($x) - 1) +@define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x) # binary @define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y)