diff --git a/Project.toml b/Project.toml index b0849033c..b9a644db7 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.9" -ChainRulesTestUtils = "0.4.1" +ChainRulesTestUtils = "0.4.2" Compat = "3" FiniteDifferences = "0.10" Reexport = "0.2" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 4af705dc3..392080e5d 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -21,6 +21,7 @@ if VERSION < v"1.3.0-DEV.142" import LinearAlgebra: dot end +include("rulesets/Base/utils.jl") include("rulesets/Base/base.jl") include("rulesets/Base/fastmath_able.jl") include("rulesets/Base/array.jl") diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 7f474e0c5..c70af02c8 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -3,15 +3,80 @@ @scalar_rule one(x) zero(x) @scalar_rule zero(x) zero(x) -@scalar_rule adjoint(x::Real) One() @scalar_rule transpose(x) One() + +# `adjoint` + +frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz') + +function rrule(::typeof(adjoint), z::Number) + adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ') + return (z', adjoint_pullback) +end + +# `real` + +@scalar_rule real(x::Real) One() + +frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz)) + +function rrule(::typeof(real), z::Number) + # add zero(z) to embed the real number in the same number type as z + real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z)) + return (real(z), real_pullback) +end + +# `imag` + @scalar_rule imag(x::Real) Zero() + +frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz)) + +function rrule(::typeof(imag), z::Complex) + imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im) + return (imag(z), imag_pullback) +end + +# `Complex` + +frule((_, Δz), ::Type{T}, z::Number) where {T<:Complex} = (T(z), Complex(Δz)) +function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex} + return (T(x, y), Complex(Δx, Δy)) +end + 
+function rrule(::Type{T}, z::Complex) where {T<:Complex} + Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ)) + return (T(z), Complex_pullback) +end +function rrule(::Type{T}, x::Real) where {T<:Complex} + Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ)) + return (T(x), Complex_pullback) +end +function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} + Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ)) + return (T(x, y), Complex_pullback) +end + +# `hypot` + @scalar_rule hypot(x::Real) sign(x) +function frule((_, Δz), ::typeof(hypot), z::Complex) + Ω = hypot(z) + ∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω) + return Ω, ∂Ω +end + +function rrule(::typeof(hypot), z::Complex) + Ω = hypot(z) + function hypot_pullback(ΔΩ) + return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) + end + return (Ω, hypot_pullback) +end @scalar_rule fma(x, y, z) (y, x, One()) @scalar_rule muladd(x, y, z) (y, x, One()) -@scalar_rule real(x::Real) One() @scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist()) @scalar_rule( mod(x, y), diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 559b64726..7f7562dc6 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -25,7 +25,7 @@ let @scalar_rule cbrt(x) inv(3 * Ω ^ 2) @scalar_rule inv(x) -(Ω ^ 2) @scalar_rule sqrt(x) inv(2Ω) - @scalar_rule exp(x::Real) Ω + @scalar_rule exp(x) Ω @scalar_rule exp10(x) Ω * log(oftype(x, 10)) @scalar_rule exp2(x) Ω * log(oftype(x, 2)) @scalar_rule expm1(x) exp(x) @@ -37,37 +37,25 @@ let # Unary complex functions ## abs - function frule((_, Δx), ::typeof(abs), x::Real) - return abs(x), sign(x) * real(Δx) - end - function frule((_, Δz), ::typeof(abs), z::Complex) - Ω = abs(z) - return Ω, (real(z) * real(Δz) + imag(z) * imag(Δz)) / ifelse(iszero(z), one(Ω), Ω) + function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex}) + Ω = abs(x) + signx = x isa Real ? 
sign(x) : x / ifelse(iszero(x), one(Ω), Ω) # `ifelse` is applied only to denominator to ensure type-stability. + return Ω, _realconjtimes(signx, Δx) end - function rrule(::typeof(abs), x::Real) + function rrule(::typeof(abs), x::Union{Real, Complex}) + Ω = abs(x) function abs_pullback(ΔΩ) - return (NO_FIELDS, real(ΔΩ)*sign(x)) - end - return abs(x), abs_pullback - end - function rrule(::typeof(abs), z::Complex) - Ω = abs(z) - function abs_pullback(ΔΩ) - Δu = real(ΔΩ) - return (NO_FIELDS, Δu*z/ifelse(iszero(z), one(Ω), Ω)) - # `ifelse` is applied only to denominator to ensure type-stability. + signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) + return (NO_FIELDS, signx * real(ΔΩ)) end return Ω, abs_pullback end ## abs2 - function frule((_, Δx), ::typeof(abs2), x::Real) - return abs2(x), 2x * real(Δx) - end - function frule((_, Δz), ::typeof(abs2), z::Complex) - return abs2(z), 2 * (real(z) * real(Δz) + imag(z) * imag(Δz)) + function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex}) + return abs2(z), 2 * _realconjtimes(z, Δz) end function rrule(::typeof(abs2), z::Union{Real, Complex}) @@ -90,20 +78,13 @@ let end ## angle - function frule((_, Δz), ::typeof(angle), x::Real) - Δx, Δy = reim(Δz) - return angle(x), Δy/ifelse(iszero(x), one(x), x) - # `ifelse` is applied only to denominator to ensure type-stability. - end - function frule((_, Δz)::Tuple{<:Any, <:Real}, ::typeof(angle), x::Real) - return angle(x), Zero() - end - function frule((_, Δz), ::typeof(angle), z::Complex) - x, y = reim(z) - Δx, Δy = reim(Δz) - return angle(z), (-y*Δx + x*Δy)/ifelse(iszero(z), one(z), abs2(z)) + function frule((_, Δx), ::typeof(angle), x) + Ω = angle(x) # `ifelse` is applied only to denominator to ensure type-stability. 
+ ∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(real(x)), abs2(x)) + return Ω, ∂Ω end + function rrule(::typeof(angle), x::Real) function angle_pullback(ΔΩ::Real) return (NO_FIELDS, Zero()) @@ -126,7 +107,30 @@ let end # Binary functions - @scalar_rule hypot(x::Real, y::Real) (x / Ω, y / Ω) + + # `hypot` + + function frule( + (_, Δx, Δy), + ::typeof(hypot), + x::T, + y::T, + ) where {T<:Union{Real,Complex}} + Ω = hypot(x, y) + n = ifelse(iszero(Ω), one(Ω), Ω) + ∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n + return Ω, ∂Ω + end + + function rrule(::typeof(hypot), x::T, y::T) where {T<:Union{Real,Complex}} + Ω = hypot(x, y) + function hypot_pullback(ΔΩ) + c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) + return (NO_FIELDS, @thunk(c * x), @thunk(c * y)) + end + return (Ω, hypot_pullback) + end + @scalar_rule x + y (One(), One()) @scalar_rule x - y (One(), -1) @scalar_rule x / y (inv(y), -((x / y) / y)) @@ -146,30 +150,21 @@ let @scalar_rule +x One() @scalar_rule -x -1 - function frule((_, Δx), ::typeof(sign), x::Real) - Ω = sign(x) - ∂Ω = _sign_jvp(Ω, x, Δx) - return Ω, ∂Ω - end - function frule((_, Δz), ::typeof(sign), z::Complex) - absz = abs(ifelse(iszero(z), one(z), z)) - Ω = z / absz - ∂Ω = _sign_jvp(Ω, absz, Δz) + # `sign` + + function frule((_, Δx), ::typeof(sign), x) + n = ifelse(iszero(x), one(real(x)), abs(x)) + Ω = x isa Real ? 
sign(x) : x / n function sign_pullback(ΔΩ) - return (NO_FIELDS, _sign_jvp(Ω, absz, ΔΩ)) + ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im + return (NO_FIELDS, ∂x) end return Ω, sign_pullback end @@ -196,9 +191,3 @@ let eval(fastable_ast) # Get original definitions # we do this second so it overwrites anything we included by mistake in the fastable end - -# the jacobian for `sign` is symmetric; `_sign_jvp` gives both J * Δz and Jᵀ * ΔΩ for -# output Ω, (co)tangent Δ, and real input x or the absolute value of complex input z -_sign_jvp(Ω, absz, Δ) = Ω * ((imag(Δ) * real(Ω) - real(Δ) * imag(Ω)) / absz)im -_sign_jvp(Ω::Real, x::Real, Δ) = (imag(Δ) * Ω / ifelse(iszero(x), one(x), x)) * im -_sign_jvp(Ω::Real, x::Real, Δ::Real) = Zero() diff --git a/src/rulesets/Base/utils.jl b/src/rulesets/Base/utils.jl new file mode 100644 index 000000000..d957d555f --- /dev/null +++ b/src/rulesets/Base/utils.jl @@ -0,0 +1,15 @@ +# real(conj(x) * y) avoiding computing the imaginary part if possible +@inline _realconjtimes(x, y) = real(conj(x) * y) +@inline _realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline _realconjtimes(x::Real, y::Complex) = x * real(y) +@inline _realconjtimes(x::Complex, y::Real) = real(x) * y +@inline _realconjtimes(x::Real, y::Real) = x * y + +# imag(conj(x) * y) avoiding computing the real part if possible +@inline _imagconjtimes(x, y) = imag(conj(x) * y) +@inline function _imagconjtimes(x::Complex, y::Complex) + return muladd(-imag(x), real(y), real(x) * imag(y)) +end +@inline _imagconjtimes(x::Real, y::Complex) = x * imag(y) +@inline _imagconjtimes(x::Complex, y::Real) = -imag(x) * y +@inline _imagconjtimes(x::Real, y::Real) = Zero() diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 9e924b138..e2c621a9f 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -46,14 +46,14 @@ end # Trig @testset "Angles" begin - for x in (-0.1, 6.4) + for x in (-0.1, 6.4, 0.5 + 0.25im) 
test_scalar(deg2rad, x) test_scalar(rad2deg, x) end end @testset "Unary complex functions" begin - for x in (-4.1, 6.4) + for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) test_scalar(real, x) test_scalar(imag, x) test_scalar(hypot, x) @@ -61,6 +61,16 @@ end end + @testset "Complex" begin + test_scalar(Complex, randn()) + test_scalar(Complex, randn(ComplexF64)) + x, ẋ, x̄ = randn(3) + y, ẏ, ȳ = randn(3) + Δz = randn(ComplexF64) + frule_test(Complex, (x, ẋ), (y, ẏ)) + rrule_test(Complex, Δz, (x, x̄), (y, ȳ)) + end + @testset "*(x, y) (scalar)" begin # This is pretty important so testing it fairly heavily test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im) @@ -80,44 +90,57 @@ end end - @testset "ldexp" begin - x, Δx, x̄ = 10rand(3) - Δz = rand() + @testset "ldexp" begin + x, Δx, x̄ = 10rand(3) + Δz = rand() + + for n in (0,1,20) + # TODO: Forward test does not work when parameter is Integer + # See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/22 + #frule_test(ldexp, (x, Δx), (n, nothing)) + rrule_test(ldexp, Δz, (x, x̄), (n, nothing)) + end + end - for n in (0,1,20) - # TODO: Forward test does not work when parameter is Integer - # See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/22 - #frule_test(ldexp, (x, Δx), (n, nothing)) - rrule_test(ldexp, Δz, (x, x̄), (n, nothing)) - end - end + @testset "\\(x::$T, y::$T) (scalar)" for T in (Float64, ComplexF64) + x, ẋ, x̄, y, ẏ, ȳ, Δz = randn(T, 7) + frule_test(\, (x, ẋ), (y, ẏ)) + rrule_test(\, Δz, (x, x̄), (y, ȳ)) + end - @testset "binary function ($f)" for f in (mod, \) + @testset "mod" begin x, Δx, x̄ = 10rand(3) y, Δy, ȳ = rand(3) Δz = rand() - frule_test(f, (x, Δx), (y, Δy)) - rrule_test(f, Δz, (x, x̄), (y, ȳ)) + frule_test(mod, (x, Δx), (y, Δy)) + rrule_test(mod, Δz, (x, x̄), (y, ȳ)) end - @testset "x^n for x<0" begin - x = -15*rand() - Δx, x̄ = 10rand(2) - y, Δy, ȳ = rand(3) - Δz = rand() + @testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64) + # for real x and n, x must be >0 
+ x = T <: Real ? 15rand() : 15randn(ComplexF64) + Δx, x̄ = 10rand(T, 2) + y, Δy, ȳ = rand(T, 3) + Δz = rand(T) - frule_test(^, (-x, Δx), (y, Δy)) - rrule_test(^, Δz, (-x, x̄), (y, ȳ)) + frule_test(^, (x, Δx), (y, Δy)) + rrule_test(^, Δz, (x, x̄), (y, ȳ)) end - @testset "identity" begin - rrule_test(identity, randn(), (randn(), randn())) - rrule_test(identity, randn(4), (randn(4), randn(4))) + @testset "identity" for T in (Float64, ComplexF64) + frule_test(identity, (randn(T), randn(T))) + frule_test(identity, (randn(T, 4), randn(T, 4))) + frule_test( + identity, + (Composite{Tuple}(randn(T, 3)...), Composite{Tuple}(randn(T, 3)...)) + ) + rrule_test(identity, randn(T), (randn(T), randn(T))) + rrule_test(identity, randn(T, 4), (randn(T, 4), randn(T, 4))) rrule_test( - identity, Tuple(randn(3)), - (Composite{Tuple}(randn(3)...), Composite{Tuple}(randn(3)...)) + identity, Tuple(randn(T, 3)), + (Composite{Tuple}(randn(T, 3)...), Composite{Tuple}(randn(T, 3)...)) ) end @@ -126,15 +149,26 @@ test_scalar(zero, x) end - @testset "trinary ($f)" for f in (muladd, fma) + @testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64) + x, Δx, x̄ = 10randn(T, 3) + y, Δy, ȳ = randn(T, 3) + z, Δz, z̄ = randn(T, 3) + Δk = randn(T) + + frule_test(muladd, (x, Δx), (y, Δy), (z, Δz)) + rrule_test(muladd, Δk, (x, x̄), (y, ȳ), (z, z̄)) + end + + @testset "fma" begin x, Δx, x̄ = 10randn(3) y, Δy, ȳ = randn(3) z, Δz, z̄ = randn(3) Δk = randn() - frule_test(f, (x, Δx), (y, Δy), (z, Δz)) - rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄)) + frule_test(fma, (x, Δx), (y, Δy), (z, Δz)) + rrule_test(fma, Δk, (x, x̄), (y, ȳ), (z, z̄)) end + @testset "clamp" begin x̄, ȳ, z̄ = randn(3) Δx, Δy, Δz = randn(3) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 42bbec3b7..36dd4c2ea 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -55,9 +55,9 @@ const FASTABLE_AST = quote test_scalar(atan, x) end @testset 
"Multivariate" begin - @testset "sincos" begin - x, Δx, x̄ = randn(3) - Δz = (randn(), randn()) + @testset "sincos(x::$T)" for T in (Float64, ComplexF64) + x, Δx, x̄ = randn(T, 3) + Δz = (randn(T), randn(T)) frule_test(sincos, (x, Δx)) rrule_test(sincos, Δz, (x, x̄)) @@ -66,7 +66,7 @@ const FASTABLE_AST = quote end @testset "exponents" begin - for x in (-0.1, 6.4) + for x in (-0.1, 6.4, 0.5 + 0.25im) test_scalar(inv, x) test_scalar(exp, x) @@ -74,9 +74,11 @@ const FASTABLE_AST = quote test_scalar(exp10, x) test_scalar(expm1, x) - test_scalar(cbrt, x) + if x isa Real + test_scalar(cbrt, x) + end - if x >= 0 + if x isa Complex || x >= 0 test_scalar(sqrt, x) test_scalar(log, x) test_scalar(log2, x) @@ -100,22 +102,43 @@ const FASTABLE_AST = quote end @test frule((Zero(), randn()), angle, randn())[2] === Zero() @test rrule(angle, randn())[2](randn())[2] === Zero() + + # test that real primal with complex tangent gives complex tangent + ΔΩ = randn(ComplexF64) + for x in (-0.5, 2.0) + @test isapprox( + frule((Zero(), ΔΩ), angle, x)[2], + frule((Zero(), ΔΩ), angle, complex(x))[2], + ) + end end @testset "Unary functions" begin - for x in (-4.1, 6.4) + for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) test_scalar(+, x) test_scalar(-, x) + test_scalar(atan, x) end end - @testset "binary function ($f)" for f in (/, +, -, hypot, atan, rem, ^, max, min) - x, Δx, x̄ = 10rand(3) - y, Δy, ȳ = rand(3) - Δz = rand() + @testset "binary functions" begin + @testset "$f(x, y)" for f in (atan, rem, max, min) + x, Δx, x̄ = 10rand(3) + y, Δy, ȳ = rand(3) + Δz = rand() - frule_test(f, (x, Δx), (y, Δy)) - rrule_test(f, Δz, (x, x̄), (y, ȳ)) + frule_test(f, (x, Δx), (y, Δy)) + rrule_test(f, Δz, (x, x̄), (y, ȳ)) + end + + @testset "$f(x::$T, y::$T)" for f in (/, +, -, hypot), T in (Float64, ComplexF64) + x, Δx, x̄ = 10rand(T, 3) + y, Δy, ȳ = rand(T, 3) + Δz = randn(typeof(f(x, y))) + + frule_test(f, (x, Δx), (y, Δy)) + rrule_test(f, Δz, (x, x̄), (y, ȳ)) + end end @testset "sign" begin