diff --git a/Project.toml b/Project.toml
index a8289bc0e..a3b2c3d26 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRulesCore"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.7.1"
+version = "1.8.0"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
diff --git a/docs/src/api.md b/docs/src/api.md
index 9b998c0af..e883c7c61 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -10,7 +10,10 @@ Private = false
 ## Rule Definition Tools
 ```@autodocs
 Modules = [ChainRulesCore]
-Pages = ["rule_definition_tools.jl"]
+Pages = [
+    "rule_definition_tools.jl",
+    "complex_math.jl",
+]
 Private = false
 ```
 
diff --git a/docs/src/complex.md b/docs/src/complex.md
index c92d22712..132d36996 100644
--- a/docs/src/complex.md
+++ b/docs/src/complex.md
@@ -87,3 +87,6 @@ end
 There are various notions of complex derivatives (holomorphic and Wirtinger derivatives, Jacobians, gradients, etc.) which differ in subtle but important ways.
 The goal of ChainRules is to provide the basic differentiation rules upon which these derivatives can be implemented, but it does not implement these derivatives itself.
 It is recommended that you carefully check how the above definitions of `frule` and `rrule` translate into your specific notion of complex derivative, since getting this wrong will quietly give you wrong results.
+
+!!! note
+    If you implement `rrule` for a non-holomorphic function, [`realdot`](@ref) and [`imagdot`](@ref) can be useful.
diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl
index dbc0f057a..00d4a111a 100644
--- a/src/ChainRulesCore.jl
+++ b/src/ChainRulesCore.jl
@@ -16,6 +16,8 @@ export add!!  # gradient accumulation operations
 export ignore_derivatives, @ignore_derivatives
 # differentials
 export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
+# helpers for rules with complex numbers
+export realdot, imagdot
 
 include("compat.jl")
 include("debug_mode.jl")
@@ -34,6 +36,7 @@ include("config.jl")
 include("rules.jl")
 include("rule_definition_tools.jl")
 include("ignore_derivatives.jl")
+include("complex_math.jl")
 
 include("deprecated.jl")
 
diff --git a/src/complex_math.jl b/src/complex_math.jl
new file mode 100644
index 000000000..0399120b1
--- /dev/null
+++ b/src/complex_math.jl
@@ -0,0 +1,39 @@
+"""
+    realdot(x, y)
+
+Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible.
+
+This function can be useful if you implement an `rrule` for a non-holomorphic function
+on complex numbers.
+
+See also: [`imagdot`](@ref)
+"""
+@inline realdot(x, y) = real(dot(x, y))
+# real(conj(x) * y) == real(x) * real(y) + imag(x) * imag(y)
+@inline realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y))
+@inline realdot(x::Real, y::Number) = x * real(y)
+@inline realdot(x::Number, y::Real) = real(x) * y
+@inline realdot(x::Real, y::Real) = x * y
+
+"""
+    imagdot(x, y)
+
+Compute `imag(dot(x, y))` while avoiding computing the real part if possible.
+
+This function can be useful if you implement an `rrule` for a non-holomorphic function
+on complex numbers.
+
+If both arguments are real numbers or arrays of real numbers, the imaginary part is
+known to be zero and `ZeroTangent()` is returned instead of a number.
+
+See also: [`realdot`](@ref)
+"""
+@inline imagdot(x, y) = imag(dot(x, y))
+# imag(conj(x) * y) == real(x) * imag(y) - imag(x) * real(y)
+@inline function imagdot(x::Number, y::Number)
+    return muladd(-imag(x), real(y), real(x) * imag(y))
+end
+@inline imagdot(x::Real, y::Number) = x * imag(y)
+@inline imagdot(x::Number, y::Real) = -imag(x) * y
+@inline imagdot(x::Real, y::Real) = ZeroTangent()
+@inline imagdot(x::AbstractArray{<:Real}, y::AbstractArray{<:Real}) = ZeroTangent()
diff --git a/test/complex_math.jl b/test/complex_math.jl
new file mode 100644
index 000000000..6963c0eb1
--- /dev/null
+++ b/test/complex_math.jl
@@ -0,0 +1,39 @@
+# structs need to be defined outside of testsets for Julia 1.0 compat
+# custom complex number to test fallback definition
+struct CustomComplex{T}
+    re::T
+    im::T
+end
+
+Base.real(x::CustomComplex) = x.re
+Base.imag(x::CustomComplex) = x.im
+
+function LinearAlgebra.dot(a::CustomComplex, b::Number)
+    return CustomComplex(reim((a.re - a.im * im) * b)...)
+end
+function LinearAlgebra.dot(a::Number, b::CustomComplex)
+    return CustomComplex(reim(conj(a) * (b.re + b.im * im))...)
+end
+function LinearAlgebra.dot(a::CustomComplex, b::CustomComplex)
+    return CustomComplex(reim((a.re - a.im * im) * (b.re + b.im * im))...)
+end
+
+@testset "complex_math.jl" begin
+    @testset "dot" begin
+        scalars = (randn(), randn(ComplexF64), CustomComplex(reim(randn(ComplexF64))...))
+        arrays = (randn(10), randn(ComplexF64, 10))
+        for inputs in (scalars, arrays)
+            for x in inputs, y in inputs
+                # `muladd` may be compiled to a fused multiply-add, so the specialized
+                # methods can differ from `dot` by rounding error; compare with a tolerance
+                @test realdot(x, y) ≈ real(dot(x, y)) atol=1e-12
+
+                if eltype(x) <: Real && eltype(y) <: Real
+                    @test imagdot(x, y) === ZeroTangent()
+                else
+                    @test imagdot(x, y) ≈ imag(dot(x, y)) atol=1e-12
+                end
+            end
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 6a4684d03..6a49a093b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -22,6 +22,7 @@ using Test
     include("rule_definition_tools.jl")
     include("config.jl")
     include("ignore_derivatives.jl")
+    include("complex_math.jl")
 
     include("deprecated.jl")
 end