Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
90a369d
Fix indentation
sethaxen Jun 28, 2020
a49187b
Test \ on complex inputs
sethaxen Jun 28, 2020
aa26f04
Test ^ on complex inputs
sethaxen Jun 28, 2020
e47afd2
Test identity on complex inputs
sethaxen Jun 28, 2020
beaae0c
Test muladd on complex inputs
sethaxen Jun 28, 2020
b89b588
Test binary functions on complex inputs
sethaxen Jun 28, 2020
4738c48
Test functions on complex inputs
sethaxen Jun 28, 2020
0f34193
Release type constraint on exp
sethaxen Jun 28, 2020
0dd4023
Add _realconjtimes
sethaxen Jun 28, 2020
d737cdf
Use _realconjtimes in abs/abs2 rules
sethaxen Jun 28, 2020
e5276e6
Add complex rule for hypot
sethaxen Jun 28, 2020
ac58495
Add generic rule for adjoint
sethaxen Jun 28, 2020
7f6c709
Add generic rule for real
sethaxen Jun 28, 2020
b5fef9e
Add generic rule for imag
sethaxen Jun 28, 2020
45ba9b7
Add complex rule for hypot
sethaxen Jun 28, 2020
5971f4f
Add rules/tests for Complex
sethaxen Jun 28, 2020
45b2edc
Test frule for identity
sethaxen Jun 28, 2020
c678197
Add missing angle test
sethaxen Jun 28, 2020
3b3d11d
Make inline just in case
sethaxen Jun 29, 2020
f6f7c8a
Unify abs rules
sethaxen Jun 29, 2020
16f8307
Introduce _imagconjtimes utility function
sethaxen Jun 29, 2020
86dac82
Unify angle rules
sethaxen Jun 29, 2020
071b658
Unify sign rules
sethaxen Jun 29, 2020
3ff0457
Merge branch 'master' into complextests
sethaxen Jun 29, 2020
e0dd0d3
Multiply by correct variable
sethaxen Jun 29, 2020
2ce61d4
Fix argument order
sethaxen Jun 29, 2020
70f6d70
Merge branch 'complextests' of https://github.com/sethaxen/ChainRules…
sethaxen Jun 29, 2020
1c97506
Bump ChainRulesTestUtils version number
sethaxen Jun 29, 2020
c286f41
Restrict to Complex
sethaxen Jun 30, 2020
980c413
Use muladd
sethaxen Jun 30, 2020
34419ce
Update src/rulesets/Base/fastmath_able.jl
sethaxen Jun 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9"
ChainRulesTestUtils = "0.4.1"
ChainRulesTestUtils = "0.4.2"
Compat = "3"
FiniteDifferences = "0.10"
Reexport = "0.2"
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

include("rulesets/Base/utils.jl")
include("rulesets/Base/base.jl")
include("rulesets/Base/fastmath_able.jl")
include("rulesets/Base/array.jl")
Expand Down
69 changes: 67 additions & 2 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,80 @@

@scalar_rule one(x) zero(x)
@scalar_rule zero(x) zero(x)
@scalar_rule adjoint(x::Real) One()
@scalar_rule transpose(x) One()

# `adjoint`

frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')

function rrule(::typeof(adjoint), z::Number)
adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ')
return (z', adjoint_pullback)
end

# `real`

@scalar_rule real(x::Real) One()

frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz))

function rrule(::typeof(real), z::Number)
# add zero(z) to embed the real number in the same number type as z
real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z))
return (real(z), real_pullback)
end

# `imag`

@scalar_rule imag(x::Real) Zero()

frule((_, Δz), ::typeof(imag), z::Complex) = (imag(z), imag(Δz))

function rrule(::typeof(imag), z::Complex)
imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im)
return (imag(z), imag_pullback)
end

# `Complex`

frule((_, Δz), ::Type{T}, z::Number) where {T<:Complex} = (T(z), Complex(Δz))
function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex}
return (T(x, y), Complex(Δx, Δy))
end

function rrule(::Type{T}, z::Complex) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ))
return (T(z), Complex_pullback)
end
function rrule(::Type{T}, x::Real) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ))
return (T(x), Complex_pullback)
end
function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ))
return (T(x, y), Complex_pullback)
end

# `hypot`

@scalar_rule hypot(x::Real) sign(x)

function frule((_, Δz), ::typeof(hypot), z::Complex)
Ω = hypot(z)
∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω)
return Ω, ∂Ω
end

function rrule(::typeof(hypot), z::Complex)
Ω = hypot(z)
function hypot_pullback(ΔΩ)
return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z)
end
return (Ω, hypot_pullback)
end

@scalar_rule fma(x, y, z) (y, x, One())
@scalar_rule muladd(x, y, z) (y, x, One())
@scalar_rule real(x::Real) One()
@scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist())
@scalar_rule(
mod(x, y),
Expand Down
113 changes: 51 additions & 62 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ let
@scalar_rule cbrt(x) inv(3 * Ω ^ 2)
@scalar_rule inv(x) -(Ω ^ 2)
@scalar_rule sqrt(x) inv(2Ω)
@scalar_rule exp(x::Real) Ω
@scalar_rule exp(x) Ω
@scalar_rule exp10(x) Ω * log(oftype(x, 10))
@scalar_rule exp2(x) Ω * log(oftype(x, 2))
@scalar_rule expm1(x) exp(x)
Expand All @@ -37,37 +37,25 @@ let

# Unary complex functions
## abs
function frule((_, Δx), ::typeof(abs), x::Real)
return abs(x), sign(x) * real(Δx)
end
function frule((_, Δz), ::typeof(abs), z::Complex)
Ω = abs(z)
return Ω, (real(z) * real(Δz) + imag(z) * imag(Δz)) / ifelse(iszero(z), one(Ω), Ω)
function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex})
Ω = abs(x)
signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω)
# `ifelse` is applied only to denominator to ensure type-stability.
return Ω, _realconjtimes(signx, Δx)
end

function rrule(::typeof(abs), x::Real)
function rrule(::typeof(abs), x::Union{Real, Complex})
Ω = abs(x)
function abs_pullback(ΔΩ)
return (NO_FIELDS, real(ΔΩ)*sign(x))
end
return abs(x), abs_pullback
end
function rrule(::typeof(abs), z::Complex)
Ω = abs(z)
function abs_pullback(ΔΩ)
Δu = real(ΔΩ)
return (NO_FIELDS, Δu*z/ifelse(iszero(z), one(Ω), Ω))
# `ifelse` is applied only to denominator to ensure type-stability.
signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω)
return (NO_FIELDS, signx * real(ΔΩ))
end
return Ω, abs_pullback
end

## abs2
function frule((_, Δx), ::typeof(abs2), x::Real)
return abs2(x), 2x * real(Δx)
end
function frule((_, Δz), ::typeof(abs2), z::Complex)
return abs2(z), 2 * (real(z) * real(Δz) + imag(z) * imag(Δz))
function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex})
return abs2(z), 2 * _realconjtimes(z, Δz)
end

function rrule(::typeof(abs2), z::Union{Real, Complex})
Expand All @@ -90,20 +78,13 @@ let
end

## angle
function frule((_, Δz), ::typeof(angle), x::Real)
Δx, Δy = reim(Δz)
return angle(x), Δy/ifelse(iszero(x), one(x), x)
# `ifelse` is applied only to denominator to ensure type-stability.
end
function frule((_, Δz)::Tuple{<:Any, <:Real}, ::typeof(angle), x::Real)
return angle(x), Zero()
end
function frule((_, Δz), ::typeof(angle), z::Complex)
x, y = reim(z)
Δx, Δy = reim(Δz)
return angle(z), (-y*Δx + x*Δy)/ifelse(iszero(z), one(z), abs2(z))
function frule((_, Δx), ::typeof(angle), x)
Ω = angle(x)
# `ifelse` is applied only to denominator to ensure type-stability.
∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(x), abs2(x))
return Ω, ∂Ω
end

function rrule(::typeof(angle), x::Real)
function angle_pullback(ΔΩ::Real)
return (NO_FIELDS, Zero())
Expand All @@ -126,7 +107,30 @@ let
end

# Binary functions
@scalar_rule hypot(x::Real, y::Real) (x / Ω, y / Ω)

# `hypot`

function frule(
(_, Δx, Δy),
::typeof(hypot),
x::T,
y::T,
) where {T<:Union{Real,Complex}}
Ω = hypot(x, y)
n = ifelse(iszero(Ω), one(Ω), Ω)
∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n
return Ω, ∂Ω
end

function rrule(::typeof(hypot), x::T, y::T) where {T<:Union{Real,Complex}}
Ω = hypot(x, y)
function hypot_pullback(ΔΩ)
c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)
return (NO_FIELDS, @thunk(c * x), @thunk(c * y))
end
return (Ω, hypot_pullback)
end

@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x / y (inv(y), -((x / y) / y))
Expand All @@ -146,30 +150,21 @@ let
@scalar_rule +x One()
@scalar_rule -x -1

function frule((_, Δx), ::typeof(sign), x::Real)
Ω = sign(x)
∂Ω = _sign_jvp(Ω, x, Δx)
return Ω, ∂Ω
end
function frule((_, Δz), ::typeof(sign), z::Complex)
absz = abs(ifelse(iszero(z), one(z), z))
Ω = z / absz
∂Ω = _sign_jvp(Ω, absz, Δz)
# `sign`

function frule((_, Δx), ::typeof(sign), x)
n = ifelse(iszero(x), one(x), abs(x))
Ω = x isa Real ? sign(x) : x / n
∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im
return Ω, ∂Ω
end

function rrule(::typeof(sign), x::Real)
Ω = sign(x)
function sign_pullback(ΔΩ)
return (NO_FIELDS, _sign_jvp(Ω, x, ΔΩ))
end
return Ω, sign_pullback
end
function rrule(::typeof(sign), z::Complex)
absz = abs(ifelse(iszero(z), one(z), z))
Ω = z / absz
function rrule(::typeof(sign), x)
n = ifelse(iszero(x), one(x), abs(x))
Ω = x isa Real ? sign(x) : x / n
function sign_pullback(ΔΩ)
return (NO_FIELDS, _sign_jvp(Ω, absz, ΔΩ))
∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im
return (NO_FIELDS, ∂x)
end
return Ω, sign_pullback
end
Expand All @@ -196,9 +191,3 @@ let
eval(fastable_ast) # Get original definitions
# we do this second so it overwrites anything we included by mistake in the fastable
end

# the jacobian for `sign` is symmetric; `_sign_jvp` gives both J * Δz and Jᵀ * ΔΩ for
# output Ω, (co)tangent Δ, and real input x or the absolute value of complex input z
_sign_jvp(Ω, absz, Δ) = Ω * ((imag(Δ) * real(Ω) - real(Δ) * imag(Ω)) / absz)im
_sign_jvp(Ω::Real, x::Real, Δ) = (imag(Δ) * Ω / ifelse(iszero(x), one(x), x)) * im
_sign_jvp(Ω::Real, x::Real, Δ::Real) = Zero()
15 changes: 15 additions & 0 deletions src/rulesets/Base/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# real(conj(x) * y) avoiding computing the imaginary part if possible
@inline _realconjtimes(x, y) = real(conj(x) * y)
@inline _realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y))
@inline _realconjtimes(x::Real, y::Complex) = x * real(y)
@inline _realconjtimes(x::Complex, y::Real) = real(x) * y
@inline _realconjtimes(x::Real, y::Real) = x * y

# imag(conj(x) * y) avoiding computing the real part if possible
@inline _imagconjtimes(x, y) = imag(conj(x) * y)
@inline function _imagconjtimes(x::Complex, y::Complex)
return muladd(-imag(x), real(y), real(x) * imag(y))
end
@inline _imagconjtimes(x::Real, y::Complex) = x * imag(y)
@inline _imagconjtimes(x::Complex, y::Real) = -imag(x) * y
@inline _imagconjtimes(x::Real, y::Real) = Zero()
Loading