From 90a369db69956436fd5dfd04892ddce004505a10 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:56:02 -0700 Subject: [PATCH 01/29] Fix indentation --- src/ChainRules.jl | 1 + test/rulesets/Base/base.jl | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 4af705dc3..392080e5d 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -21,6 +21,7 @@ if VERSION < v"1.3.0-DEV.142" import LinearAlgebra: dot end +include("rulesets/Base/utils.jl") include("rulesets/Base/base.jl") include("rulesets/Base/fastmath_able.jl") include("rulesets/Base/array.jl") diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index ee016d3c3..ddaa9d636 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -95,17 +95,18 @@ @test extern(dy) == extern(zeros(2, 5) .+ dy) end - @testset "ldexp" begin - x, Δx, x̄ = 10rand(3) - Δz = rand() - - for n in (0,1,20) - # TODO: Forward test does not work when parameter is Integer - # See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/22 - #frule_test(ldexp, (x, Δx), (n, nothing)) - rrule_test(ldexp, Δz, (x, x̄), (n, nothing)) - end - end + @testset "ldexp" begin + x, Δx, x̄ = 10rand(3) + Δz = rand() + + for n in (0,1,20) + # TODO: Forward test does not work when parameter is Integer + # See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/22 + #frule_test(ldexp, (x, Δx), (n, nothing)) + rrule_test(ldexp, Δz, (x, x̄), (n, nothing)) + end + end + @testset "binary function ($f)" for f in (mod, \) x, Δx, x̄ = 10rand(3) From a49187bbea6a854afd899a5069846f3b2d1cf780 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:56:41 -0700 Subject: [PATCH 02/29] Test \ on complex inputs --- test/rulesets/Base/base.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index ddaa9d636..a768e29fb 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -107,14 +107,19 @@ end end + @testset "\\(x::$T, y::$T) (scalar)" for T in (Float64, ComplexF64) + x, ẋ, x̄, y, ẏ, ȳ, Δz = randn(T, 7) + frule_test(*, (x, ẋ), (y, ẏ)) + rrule_test(*, Δz, (x, x̄), (y, ȳ)) + end - @testset "binary function ($f)" for f in (mod, \) + @testset "mod" begin x, Δx, x̄ = 10rand(3) y, Δy, ȳ = rand(3) Δz = rand() - frule_test(f, (x, Δx), (y, Δy)) - rrule_test(f, Δz, (x, x̄), (y, ȳ)) + frule_test(mod, (x, Δx), (y, Δy)) + rrule_test(mod, Δz, (x, x̄), (y, ȳ)) end @testset "x^n for x<0" begin From aa26f0456d61c83bc805ebf7416fed4a2d7f3c38 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:57:03 -0700 Subject: [PATCH 03/29] Test ^ on complex inputs --- test/rulesets/Base/base.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a768e29fb..9d6ca979d 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -122,14 +122,15 @@ rrule_test(mod, Δz, (x, x̄), (y, ȳ)) end - @testset "x^n for x<0" begin - x = -15*rand() - Δx, x̄ = 10rand(2) - y, Δy, ȳ = rand(3) - Δz = rand() - - frule_test(^, (-x, Δx), (y, Δy)) - rrule_test(^, Δz, (-x, x̄), (y, ȳ)) + @testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64) + # for real x and n, x must be >0 + x = T <: Real ? 15rand() : 15randn(ComplexF64) + Δx, x̄ = 10rand(T, 2) + y, Δy, ȳ = rand(T, 3) + Δz = rand(T) + + frule_test(^, (x, Δx), (y, Δy)) + rrule_test(^, Δz, (x, x̄), (y, ȳ)) end @testset "identity" begin From e47afd28794d9aa3d56981a79c0af8404506bb75 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:57:28 -0700 Subject: [PATCH 04/29] Test identity on complex inputs --- test/rulesets/Base/base.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 9d6ca979d..b7bc613ef 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -133,13 +133,13 @@ rrule_test(^, Δz, (x, x̄), (y, ȳ)) end - @testset "identity" begin - rrule_test(identity, randn(), (randn(), randn())) - rrule_test(identity, randn(4), (randn(4), randn(4))) + @testset "identity" for T in (Float64, ComplexF64) + rrule_test(identity, randn(T), (randn(T), randn(T))) + rrule_test(identity, randn(T, 4), (randn(T, 4), randn(T, 4))) rrule_test( - identity, Tuple(randn(3)), - (Composite{Tuple}(randn(3)...), Composite{Tuple}(randn(3)...)) + identity, Tuple(randn(T, 3)), + (Composite{Tuple}(randn(T, 3)...), Composite{Tuple}(randn(T, 3)...)) ) end From beaae0c65f5aee29739280bceb225f4a65ecbbe9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:57:53 -0700 Subject: [PATCH 05/29] Test muladd on complex inputs --- test/rulesets/Base/base.jl | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index b7bc613ef..ec46a4a8e 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -148,15 +148,26 @@ test_scalar(zero, x) end - @testset "trinary ($f)" for f in (muladd, fma) + @testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64) + x, Δx, x̄ = 10randn(T, 3) + y, Δy, ȳ = randn(T, 3) + z, Δz, z̄ = randn(T, 3) + Δk = randn(T) + + frule_test(muladd, (x, Δx), (y, Δy), (z, Δz)) + rrule_test(muladd, Δk, (x, x̄), (y, ȳ), (z, z̄)) + end + + @testset "fma" begin x, Δx, x̄ = 10randn(3) y, Δy, ȳ = randn(3) z, Δz, z̄ = randn(3) Δk = randn() - frule_test(f, (x, Δx), (y, Δy), (z, Δz)) - rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄)) + frule_test(fma, (x, Δx), (y, Δy), (z, Δz)) + rrule_test(fma, Δk, (x, x̄), (y, ȳ), (z, z̄)) end + @testset "clamp" begin x̄, ȳ, z̄ = randn(3) Δx, Δy, Δz = randn(3) From b89b588125858f9b694ed4cb289e36a8148ccc99 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:58:53 -0700 Subject: [PATCH 06/29] Test binary functions on complex inputs --- test/rulesets/Base/fastmath_able.jl | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 42bbec3b7..a528e3a54 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -109,13 +109,24 @@ const FASTABLE_AST = quote end end - @testset "binary function ($f)" for f in (/, +, -, hypot, atan, rem, ^, max, min) - x, Δx, x̄ = 10rand(3) - y, Δy, ȳ = rand(3) - Δz = rand() + @testset "binary functions" begin + @testset "$f(x, y)" for f in (atan, rem, max, min) + x, Δx, x̄ = 10rand(3) + y, Δy, ȳ = rand(3) + Δz = rand() + + frule_test(f, (x, Δx), (y, Δy)) + rrule_test(f, Δz, (x, x̄), (y, ȳ)) + end + + @testset "$f(x::$T, y::$T)" for f in (/, +, -, hypot), T in (Float64, ComplexF64) + x, Δx, x̄ = 10rand(T, 3) + y, Δy, ȳ = rand(T, 3) + Δz = randn(typeof(f(x, y))) - frule_test(f, (x, Δx), (y, Δy)) - rrule_test(f, Δz, (x, x̄), (y, ȳ)) + frule_test(f, (x, Δx), (y, Δy)) + rrule_test(f, Δz, (x, x̄), (y, ȳ)) + end end @testset "sign" begin From 4738c48041a34ee7b4e2a57faea1cdc3d72245d0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:59:13 -0700 Subject: [PATCH 07/29] Test functions on complex inputs --- test/rulesets/Base/base.jl | 4 ++-- test/rulesets/Base/fastmath_able.jl | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index ec46a4a8e..e16972511 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -46,14 +46,14 @@ end # Trig @testset "Angles" begin - for x in (-0.1, 6.4) + for x in (-0.1, 6.4, 0.5 + 0.25im) test_scalar(deg2rad, x) test_scalar(rad2deg, x) end end @testset "Unary complex functions" begin - for x in (-4.1, 6.4) + for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) test_scalar(real, x) test_scalar(imag, x) test_scalar(hypot, x) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index a528e3a54..fa0b108fe 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -55,9 +55,9 @@ const FASTABLE_AST = quote test_scalar(atan, x) end @testset "Multivariate" begin - @testset "sincos" begin - x, Δx, x̄ = randn(3) - Δz = (randn(), randn()) + @testset "sincos(x::$T)" for T in (Float64, ComplexF64) + x, Δx, x̄ = randn(T, 3) + Δz = (randn(T), randn(T)) frule_test(sincos, (x, Δx)) rrule_test(sincos, Δz, (x, x̄)) @@ -66,7 +66,7 @@ const FASTABLE_AST = quote end @testset "exponents" begin - for x in (-0.1, 6.4) + for x in (-0.1, 6.4, 0.5 + 0.25im) test_scalar(inv, x) test_scalar(exp, x) @@ -74,9 +74,11 @@ const FASTABLE_AST = quote test_scalar(exp10, x) test_scalar(expm1, x) - test_scalar(cbrt, x) + if x isa Real + test_scalar(cbrt, x) + end - if x >= 0 + if x isa Complex || x >= 0 test_scalar(sqrt, x) test_scalar(log, x) test_scalar(log2, x) @@ -103,9 +105,10 @@ const FASTABLE_AST = quote end @testset "Unary functions" begin - for x in (-4.1, 6.4) + for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) test_scalar(+, x) test_scalar(-, x) + test_scalar(atan, x) end end From 0f341930567408489b378a1303b58182bfb9c50e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 14:59:46 -0700 Subject: [PATCH 08/29] Release type constraint on exp --- src/rulesets/Base/fastmath_able.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 559b64726..fe71d46f0 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -25,7 +25,7 @@ let @scalar_rule cbrt(x) inv(3 * Ω ^ 2) @scalar_rule inv(x) -(Ω ^ 2) @scalar_rule sqrt(x) inv(2Ω) - @scalar_rule exp(x::Real) Ω + @scalar_rule exp(x) Ω @scalar_rule exp10(x) Ω * log(oftype(x, 10)) @scalar_rule exp2(x) Ω * log(oftype(x, 2)) @scalar_rule expm1(x) exp(x) From 0dd4023640fa63213acbebca3f5cdbc700518289 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:01:45 -0700 Subject: [PATCH 09/29] Add _realconjtimes --- src/rulesets/Base/utils.jl | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/rulesets/Base/utils.jl diff --git a/src/rulesets/Base/utils.jl b/src/rulesets/Base/utils.jl new file mode 100644 index 000000000..0b112e750 --- /dev/null +++ b/src/rulesets/Base/utils.jl @@ -0,0 +1,5 @@ +# real(x * conj(y)) avoiding computing the imaginary part +_realconjtimes(x, y) = real(x) * real(y) + imag(x) * imag(y) +_realconjtimes(x::Real, y) = x * real(y) +_realconjtimes(x, y::Real) = real(x) * y +_realconjtimes(x::Real, y::Real) = x * y From d737cdf46edadeb7e77220b7c379662f6e6beb46 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:02:23 -0700 Subject: [PATCH 10/29] Use _realconjtimes in abs/abs2 rules --- src/rulesets/Base/fastmath_able.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index fe71d46f0..6d74ebadd 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -42,7 +42,7 @@ let end function frule((_, Δz), ::typeof(abs), z::Complex) Ω = abs(z) - return Ω, (real(z) * real(Δz) + imag(z) * imag(Δz)) / ifelse(iszero(z), one(Ω), Ω) + return Ω, _realconjtimes(z, Δz) / ifelse(iszero(z), one(Ω), Ω) # `ifelse` is applied only to denominator to ensure type-stability. end @@ -63,11 +63,8 @@ let end ## abs2 - function frule((_, Δx), ::typeof(abs2), x::Real) - return abs2(x), 2x * real(Δx) - end - function frule((_, Δz), ::typeof(abs2), z::Complex) - return abs2(z), 2 * (real(z) * real(Δz) + imag(z) * imag(Δz)) + function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex}) + return abs2(z), 2 * _realconjtimes(z, Δz) end function rrule(::typeof(abs2), z::Union{Real, Complex}) From e5276e66019e4346ac227d8506dad076373833f9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:02:38 -0700 Subject: [PATCH 11/29] Add complex rule for hypot --- src/rulesets/Base/fastmath_able.jl | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 6d74ebadd..91d0afb61 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -123,7 +123,30 @@ let end # Binary functions - @scalar_rule hypot(x::Real, y::Real) (x / Ω, y / Ω) + + # `hypot` + + function frule( + (_, Δx, Δy), + ::typeof(hypot), + x::T, + y::T, + ) where {T<:Union{Real,Complex}} + Ω = hypot(x, y) + n = ifelse(iszero(Ω), one(Ω), Ω) + ∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n + return Ω, ∂Ω + end + + function rrule(::typeof(hypot), x::T, y::T) where {T<:Union{Real,Complex}} + Ω = hypot(x, y) + function hypot_pullback(ΔΩ) + c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) + return (NO_FIELDS, @thunk(c * x), @thunk(c * y)) + end + return (Ω, hypot_pullback) + end + @scalar_rule x + y (One(), One()) @scalar_rule x - y (One(), -1) @scalar_rule x / y (inv(y), -((x / y) / y)) From ac58495853c30d3c19f8244be0d12f5fe0b59fe6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:03:14 -0700 Subject: [PATCH 12/29] Add generic rule for adjoint --- src/rulesets/Base/base.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 7f474e0c5..6b9e66792 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -3,8 +3,17 @@ @scalar_rule one(x) zero(x) @scalar_rule zero(x) zero(x) -@scalar_rule adjoint(x::Real) One() @scalar_rule transpose(x) One() + +# `adjoint` + +frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') + +function rrule(::typeof(adjoint), z::Number) + adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ') + return (z', adjoint_pullback) +end + @scalar_rule imag(x::Real) Zero() @scalar_rule hypot(x::Real) sign(x) From 7f6c709999e4219b9d8e9e5929f7d646b0d5b9d7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:23:41 -0700 Subject: [PATCH 13/29] Add generic rule for real --- src/rulesets/Base/base.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6b9e66792..8789efd09 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -14,13 +14,26 @@ function rrule(::typeof(adjoint), z::Number) return (z', adjoint_pullback) end +# `real` + +@scalar_rule real(x::Real) One() + +frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz)) + +function rrule(::typeof(real), z::Number) + # add zero(z) to embed the real number in the same number type as z + real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z)) + return (real(z), real_pullback) +end + +# `imag` + @scalar_rule imag(x::Real) Zero() @scalar_rule hypot(x::Real) sign(x) @scalar_rule fma(x, y, z) (y, x, One()) @scalar_rule muladd(x, y, z) (y, x, One()) -@scalar_rule real(x::Real) One() @scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist()) @scalar_rule( mod(x, y), From b5fef9e5ac56a02339c7a081f3ee066db1ef67e5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:25:21 -0700 Subject: [PATCH 14/29] Add generic rule for imag --- src/rulesets/Base/base.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 8789efd09..1b51901d8 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -29,6 +29,14 @@ end # `imag` @scalar_rule imag(x::Real) Zero() + +frule((_, Δz), ::typeof(imag), z::Number) = (imag(z), imag(Δz)) + +function rrule(::typeof(imag), z::Complex) + imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im) + return (imag(z), imag_pullback) +end + @scalar_rule hypot(x::Real) sign(x) From 45ba9b75bf919137f09ead9832d8caaab91ac7a6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:26:48 -0700 Subject: [PATCH 15/29] Add complex rule for hypot --- src/rulesets/Base/base.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 1b51901d8..ad2e72c16 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -37,8 +37,23 @@ function rrule(::typeof(imag), z::Complex) return (imag(z), imag_pullback) end +# `hypot` + @scalar_rule hypot(x::Real) sign(x) +function frule((_, Δz), ::typeof(hypot), z::Complex) + Ω = hypot(z) + ∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω) + return Ω, ∂Ω +end + +function rrule(::typeof(hypot), z::Complex) + Ω = hypot(z) + function hypot_pullback(ΔΩ) + return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) + end + return (Ω, hypot_pullback) +end @scalar_rule fma(x, y, z) (y, x, One()) @scalar_rule muladd(x, y, z) (y, x, One()) From 5971f4f61b03a2a19f29c046aebce928b74290cd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 15:59:33 -0700 Subject: [PATCH 16/29] Add rules/tests for Complex --- src/rulesets/Base/base.jl | 20 ++++++++++++++++++++ test/rulesets/Base/base.jl | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index ad2e72c16..9af7f1745 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -37,6 +37,26 @@ function rrule(::typeof(imag), z::Complex) return (imag(z), imag_pullback) end +# `Complex` + +frule((_, Δz), ::Type{T}, z::Number) where {T<:Complex} = (T(z), Complex(Δz)) +function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex} + return (T(x, y), Complex(Δx, Δy)) +end + +function rrule(::Type{T}, z::Complex) where {T<:Complex} + Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ)) + return (T(z), Complex_pullback) +end +function rrule(::Type{T}, x::Real) where {T<:Complex} + Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ)) + return (T(x), Complex_pullback) +end +function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} + Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ)) + return (T(x, y), Complex_pullback) +end + # `hypot` @scalar_rule hypot(x::Real) sign(x) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index e16972511..84049ac9e 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -61,6 +61,16 @@ end end + @testset "Complex" begin + test_scalar(Complex, randn()) + test_scalar(Complex, randn(ComplexF64)) + x, ẋ, x̄ = randn(3) + y, ẏ, ȳ = randn(3) + Δz = randn(ComplexF64) + frule_test(Complex, (x, ẋ), (y, ẏ)) + rrule_test(Complex, Δz, (x, x̄), (y, ȳ)) + end + @testset "*(x, y) (scalar)" begin # This is pretty important so testing it fairly heavily test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im) From 45b2edc64f14fd89e3b1637b016dfa736d08ee1a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 16:34:32 -0700 Subject: [PATCH 17/29] Test frule for identity --- test/rulesets/Base/base.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 84049ac9e..45c1c516e 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -144,9 +144,15 @@ end @testset "identity" for T in (Float64, ComplexF64) + frule_test(identity, (randn(T), randn(T))) + frule_test(identity, (randn(T, 4), randn(T, 4))) + frule_test( + identity, + (Composite{Tuple}(randn(T, 3)...), Composite{Tuple}(randn(T, 3)...)) + ) + rrule_test(identity, randn(T), (randn(T), randn(T))) rrule_test(identity, randn(T, 4), (randn(T, 4), randn(T, 4))) - rrule_test( identity, Tuple(randn(T, 3)), (Composite{Tuple}(randn(T, 3)...), Composite{Tuple}(randn(T, 3)...)) From c67819737f4ef475b0a0ea8e58b78328b512ba55 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Jun 2020 16:46:43 -0700 Subject: [PATCH 18/29] Add missing angle test --- test/rulesets/Base/fastmath_able.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index fa0b108fe..36dd4c2ea 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -102,6 +102,15 @@ const FASTABLE_AST = quote end @test frule((Zero(), randn()), angle, randn())[2] === Zero() @test rrule(angle, randn())[2](randn())[2] === Zero() + + # test that real primal with complex tangent gives complex tangent + ΔΩ = randn(ComplexF64) + for x in (-0.5, 2.0) + @test isapprox( + frule((Zero(), ΔΩ), angle, x)[2], + frule((Zero(), ΔΩ), angle, complex(x))[2], + ) + end end @testset "Unary functions" begin From 3b3d11db83ac3d6a8e8fe8fa199369a707dd9aa0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 11:57:02 -0700 Subject: [PATCH 19/29] Make inline just in case --- src/rulesets/Base/utils.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/utils.jl b/src/rulesets/Base/utils.jl index 0b112e750..8406bad66 100644 --- a/src/rulesets/Base/utils.jl +++ b/src/rulesets/Base/utils.jl @@ -1,5 +1,6 @@ -# real(x * conj(y)) avoiding computing the imaginary part -_realconjtimes(x, y) = real(x) * real(y) + imag(x) * imag(y) -_realconjtimes(x::Real, y) = x * real(y) -_realconjtimes(x, y::Real) = real(x) * y -_realconjtimes(x::Real, y::Real) = x * y +# real(conj(x) * y) avoiding computing the imaginary part if possible +@inline _realconjtimes(x, y) = real(conj(x) * y) +@inline _realconjtimes(x::Complex, y::Complex) = real(x) * real(y) + imag(x) * imag(y) +@inline _realconjtimes(x::Real, y::Complex) = x * real(y) +@inline _realconjtimes(x::Complex, y::Real) = real(x) * y +@inline _realconjtimes(x::Real, y::Real) = x * y From f6f7c8aac369d9d7f67bdabd1d9a177f32e736df Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 11:57:45 -0700 Subject: [PATCH 20/29] Unify abs rules --- src/rulesets/Base/fastmath_able.jl | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 91d0afb61..73a6bef75 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -37,27 +37,18 @@ let # Unary complex functions ## abs - function frule((_, Δx), ::typeof(abs), x::Real) - return abs(x), sign(x) * real(Δx) - end - function frule((_, Δz), ::typeof(abs), z::Complex) - Ω = abs(z) - return Ω, _realconjtimes(z, Δz) / ifelse(iszero(z), one(Ω), Ω) + function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex}) + Ω = abs(x) + signx = x isa Real ? sign(x) : Ω / ifelse(iszero(x), one(Ω), Ω) # `ifelse` is applied only to denominator to ensure type-stability. + return Ω, _realconjtimes(signx, Δx) end - function rrule(::typeof(abs), x::Real) - function abs_pullback(ΔΩ) - return (NO_FIELDS, real(ΔΩ)*sign(x)) - end - return abs(x), abs_pullback - end - function rrule(::typeof(abs), z::Complex) - Ω = abs(z) + function rrule(::typeof(abs), x::Union{Real, Complex}) + Ω = abs(x) function abs_pullback(ΔΩ) - Δu = real(ΔΩ) - return (NO_FIELDS, Δu*z/ifelse(iszero(z), one(Ω), Ω)) - # `ifelse` is applied only to denominator to ensure type-stability. + signx = x isa Real ? sign(x) : Ω / ifelse(iszero(x), one(Ω), Ω) + return (NO_FIELDS, signx * real(ΔΩ)) end return Ω, abs_pullback end From 16f8307f72ec2578f6bd1e75811da88f3a2dfa49 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 11:58:02 -0700 Subject: [PATCH 21/29] Introduce _imagconjtimes utility function --- src/rulesets/Base/utils.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/rulesets/Base/utils.jl b/src/rulesets/Base/utils.jl index 8406bad66..e243a04ae 100644 --- a/src/rulesets/Base/utils.jl +++ b/src/rulesets/Base/utils.jl @@ -4,3 +4,10 @@ @inline _realconjtimes(x::Real, y::Complex) = x * real(y) @inline _realconjtimes(x::Complex, y::Real) = real(x) * y @inline _realconjtimes(x::Real, y::Real) = x * y + +# imag(conj(x) * y) avoiding computing the real part if possible +@inline _imagconjtimes(x, y) = imag(conj(x) * y) +@inline _imagconjtimes(x::Complex, y::Complex) = -imag(x) * real(y) + real(x) * imag(y) +@inline _imagconjtimes(x::Real, y::Complex) = x * imag(y) +@inline _imagconjtimes(x::Complex, y::Real) = -imag(x) * y +@inline _imagconjtimes(x::Real, y::Real) = Zero() From 86dac8257553b07622cccea50fd7758b4530acef Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 11:58:26 -0700 Subject: [PATCH 22/29] Unify angle rules --- src/rulesets/Base/fastmath_able.jl | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 73a6bef75..135ac3fcb 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -78,20 +78,13 @@ let end ## angle - function frule((_, Δz), ::typeof(angle), x::Real) - Δx, Δy = reim(Δz) - return angle(x), Δy/ifelse(iszero(x), one(x), x) - # `ifelse` is applied only to denominator to ensure type-stability. - end - function frule((_, Δz)::Tuple{<:Any, <:Real}, ::typeof(angle), x::Real) - return angle(x), Zero() - end - function frule((_, Δz), ::typeof(angle), z::Complex) - x, y = reim(z) - Δx, Δy = reim(Δz) - return angle(z), (-y*Δx + x*Δy)/ifelse(iszero(z), one(z), abs2(z)) + function frule((_, Δx), ::typeof(angle), x) + Ω = angle(x) # `ifelse` is applied only to denominator to ensure type-stability. + ∂Ω = _imagconjtimes(Δx, x) / ifelse(iszero(x), one(x), abs2(x)) + return Ω, ∂Ω end + function rrule(::typeof(angle), x::Real) function angle_pullback(ΔΩ::Real) return (NO_FIELDS, Zero()) From 071b658ad2b74cfc52556add91fb680cfca8567c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 11:58:40 -0700 Subject: [PATCH 23/29] Unify sign rules --- src/rulesets/Base/fastmath_able.jl | 37 +++++++++--------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 135ac3fcb..7d574ee3c 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -150,30 +150,21 @@ let @scalar_rule +x One() @scalar_rule -x -1 - function frule((_, Δx), ::typeof(sign), x::Real) - Ω = sign(x) - ∂Ω = _sign_jvp(Ω, x, Δx) - return Ω, ∂Ω - end - function frule((_, Δz), ::typeof(sign), z::Complex) - absz = abs(ifelse(iszero(z), one(z), z)) - Ω = z / absz - ∂Ω = _sign_jvp(Ω, absz, Δz) + # `sign` + + function frule((_, Δx), ::typeof(sign), x) + n = ifelse(iszero(x), one(x), abs(x)) + Ω = x isa Real ? sign(x) : x / n + ∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im return Ω, ∂Ω end - function rrule(::typeof(sign), x::Real) - Ω = sign(x) - function sign_pullback(ΔΩ) - return (NO_FIELDS, _sign_jvp(Ω, x, ΔΩ)) - end - return Ω, sign_pullback - end - function rrule(::typeof(sign), z::Complex) - absz = abs(ifelse(iszero(z), one(z), z)) - Ω = z / absz + function rrule(::typeof(sign), x) + n = ifelse(iszero(x), one(x), abs(x)) + Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) - return (NO_FIELDS, _sign_jvp(Ω, absz, ΔΩ)) + ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im + return (NO_FIELDS, ∂x) end return Ω, sign_pullback end @@ -200,9 +191,3 @@ let eval(fastable_ast) # Get original definitions # we do this second so it overwrites anything we included by mistake in the fastable end - -# the jacobian for `sign` is symmetric; `_sign_jvp` gives both J * Δz and Jᵀ * ΔΩ for -# output Ω, (co)tangent Δ, and real input x or the absolute value of complex input z -_sign_jvp(Ω, absz, Δ) = Ω * ((imag(Δ) * real(Ω) - real(Δ) * imag(Ω)) / absz)im -_sign_jvp(Ω::Real, x::Real, Δ) = (imag(Δ) * Ω / ifelse(iszero(x), one(x), x)) * im -_sign_jvp(Ω::Real, x::Real, Δ::Real) = Zero() From e0dd0d3a3bbd808194f82ce95434dbc9c6cc0b6a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 15:58:26 -0700 Subject: [PATCH 24/29] Multiply by correct variable --- src/rulesets/Base/fastmath_able.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 7d574ee3c..89582a511 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -39,7 +39,7 @@ let ## abs function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex}) Ω = abs(x) - signx = x isa Real ? sign(x) : Ω / ifelse(iszero(x), one(Ω), Ω) + signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) # `ifelse` is applied only to denominator to ensure type-stability. return Ω, _realconjtimes(signx, Δx) end @@ -47,7 +47,7 @@ let function rrule(::typeof(abs), x::Union{Real, Complex}) Ω = abs(x) function abs_pullback(ΔΩ) - signx = x isa Real ? sign(x) : Ω / ifelse(iszero(x), one(Ω), Ω) + signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) return (NO_FIELDS, signx * real(ΔΩ)) end return Ω, abs_pullback From 2ce61d45fd266249176c21db9b1f47504bbea584 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 15:58:34 -0700 Subject: [PATCH 25/29] Fix argument order --- src/rulesets/Base/fastmath_able.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 89582a511..8c8f1b7ef 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -81,7 +81,7 @@ let function frule((_, Δx), ::typeof(angle), x) Ω = angle(x) # `ifelse` is applied only to denominator to ensure type-stability. - ∂Ω = _imagconjtimes(Δx, x) / ifelse(iszero(x), one(x), abs2(x)) + ∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(x), abs2(x)) return Ω, ∂Ω end From 1c97506eeca8c966968591c672a77088cf4e9d6d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 16:10:58 -0700 Subject: [PATCH 26/29] Bump ChainRulesTestUtils version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b0849033c..b9a644db7 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.9" -ChainRulesTestUtils = "0.4.1" +ChainRulesTestUtils = "0.4.2" Compat = "3" FiniteDifferences = "0.10" Reexport = "0.2" From c286f4145cabcc7a2e3b0de5f1b5d1eceb1feb88 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 17:00:57 -0700 Subject: [PATCH 27/29] Restrict to Complex --- src/rulesets/Base/base.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 9af7f1745..c70af02c8 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -30,7 +30,7 @@ end @scalar_rule imag(x::Real) Zero() -frule((_, Δz), ::typeof(imag), z::Number) = (imag(z), imag(Δz)) +frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz)) function rrule(::typeof(imag), z::Complex) imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im) From 980c413a51c1db565aa01db1d06a186a49d9cb77 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 29 Jun 2020 17:05:21 -0700 Subject: [PATCH 28/29] Use muladd --- src/rulesets/Base/utils.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/utils.jl b/src/rulesets/Base/utils.jl index e243a04ae..d957d555f 100644 --- a/src/rulesets/Base/utils.jl +++ b/src/rulesets/Base/utils.jl @@ -1,13 +1,15 @@ # real(conj(x) * y) avoiding computing the imaginary part if possible @inline _realconjtimes(x, y) = real(conj(x) * y) -@inline _realconjtimes(x::Complex, y::Complex) = real(x) * real(y) + imag(x) * imag(y) +@inline _realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) @inline _realconjtimes(x::Real, y::Complex) = x * real(y) @inline _realconjtimes(x::Complex, y::Real) = real(x) * y @inline _realconjtimes(x::Real, y::Real) = x * y # imag(conj(x) * y) avoiding computing the real part if possible @inline _imagconjtimes(x, y) = imag(conj(x) * y) -@inline _imagconjtimes(x::Complex, y::Complex) = -imag(x) * real(y) + real(x) * imag(y) +@inline function _imagconjtimes(x::Complex, y::Complex) + return muladd(-imag(x), real(y), real(x) * imag(y)) +end @inline _imagconjtimes(x::Real, y::Complex) = x * imag(y) @inline _imagconjtimes(x::Complex, y::Real) = -imag(x) * y @inline _imagconjtimes(x::Real, y::Real) = Zero() From 34419ce3cdf09425a458ef1485f573c1261f9c94 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 30 Jun 2020 01:44:25 -0700 Subject: [PATCH 29/29] Update src/rulesets/Base/fastmath_able.jl Co-authored-by: willtebbutt --- src/rulesets/Base/fastmath_able.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 8c8f1b7ef..7f7562dc6 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -111,11 +111,11 @@ let # `hypot` function frule( - (_, Δx, Δy), - ::typeof(hypot), - x::T, - y::T, - ) where {T<:Union{Real,Complex}} + (_, Δx, Δy), + ::typeof(hypot), + x::T, + y::T, + ) where {T<:Union{Real,Complex}} Ω = hypot(x, y) n = ifelse(iszero(Ω), one(Ω), Ω) ∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n