From e566656ffa4efed295b85bef2207774c79db4f96 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sun, 20 Dec 2020 03:00:12 +0100 Subject: [PATCH 1/4] add ifelse and muladd rules --- src/rules.jl | 10 ++++++++++ test/runtests.jl | 32 +++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index e31f5d4..63bfd14 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -60,10 +60,12 @@ @define_diffrule Base.deg2rad(x) = :( π / 180 ) @define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 ) @define_diffrule Base.rad2deg(x) = :( 180 / π ) + @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) + @define_diffrule Base.transpose(x) = :( 1 ) @define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) @@ -93,6 +95,14 @@ end @define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) ) @define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) ) +# trinary # +#---------# + +@define_diffrule Base.muladd(x, y, z) = :($y), :($x), :(one($z)) +@define_diffrule Base.fma(x, y, z) = :($y), :($x), :(one($z)) + +@define_diffrule Base.ifelse(p, x, y) = false, :($p), :(!$p) + #################### # SpecialFunctions # #################### diff --git a/test/runtests.jl b/test/runtests.jl index ede8617..c41c29e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,7 @@ function finitediff(f, x) end -non_numeric_arg_functions = [(:Base, :rem2pi, 2)] +non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)] for (M, f, arity) in DiffRules.diffrules() (M, f, arity) ∈ non_numeric_arg_functions && continue @@ -46,6 +46,22 @@ for (M, f, arity) in DiffRules.diffrules() @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) end end + elseif arity == 3 + @test DiffRules.hasdiffrule(M, f, 3) + derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) + @eval begin + foo, bar, goo = randn(3) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + if !(isnan(dx)) + @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) + end + if !(isnan(dy)) + @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) + end + if !(isnan(dz)) + @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) + end + end end end @@ -62,3 +78,17 @@ for xtype in [:Float64, :BigFloat, :Int64] end end end + +# Test ifelse separately as first argument is boolean +@test DiffRules.hasdiffrule(:Base, :ifelse, 3) +derivs = DiffRules.diffrule(:Base, :ifelse, :foo, :bar, :goo) +for cond in [true, false] + @eval begin + foo = $cond + bar, gee = randn(2) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + @test isapprox(dy, finitediff(y -> ifelse(foo, y, goo), bar), rtol=0.05) + @test isapprox(dz, finitediff(z -> ifelse(foo, bar, z), goo), rtol=0.05) + end +end + From 391f5145534a681e721aef1385345bc3a7aadd4d Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sun, 20 Dec 2020 04:36:53 +0100 Subject: [PATCH 2/4] add 2-arg log, too, closes https://github.com/JuliaDiff/DiffRules.jl/issues/32 --- src/rules.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rules.jl b/src/rules.jl index 63bfd14..9c8e43a 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -89,6 +89,8 @@ else @define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) ) end @define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) ) +@define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) + @define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) ) @define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) ) @define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN From cbf17ea233d16deb6f0a32bca661c63a79126adc Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sun, 20 Dec 2020 04:47:07 +0100 Subject: [PATCH 3/4] trivial 1-arg functions --- src/rules.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/rules.jl b/src/rules.jl index 9c8e43a..9c0fb90 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -66,6 +66,9 @@ @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) +@define_diffrule Base.identity(x) = :( 1 ) +@define_diffrule Base.conj(x) = :( 1 ) +@define_diffrule Base.adjoint(x) = :( 1 ) @define_diffrule Base.transpose(x) = :( 1 ) @define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) From 4a849ea16670921226d71f095bf6f8dd523d6e85 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 4 Jul 2021 13:53:13 -0400 Subject: [PATCH 4/4] v1.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a93c816..6b4bdba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DiffRules" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.0.2" +version = "1.1.0" [deps] NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"