From b2d73fe17d08c47f8566f4e11c6cd9a913b408ca Mon Sep 17 00:00:00 2001 From: Antoine Levitt Date: Fri, 29 Apr 2022 15:31:03 +0200 Subject: [PATCH] Support complex output in `derivative` --- src/derivative.jl | 7 ++++++- test/DerivativeTest.jl | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/derivative.jl b/src/derivative.jl index a60d8d0e..5d7e5dab 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -70,6 +70,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba end derivative(f, x::AbstractArray) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number. Perhaps you meant gradient(f, x)?")) +derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input.")) ##################### # result extraction # @@ -78,9 +79,13 @@ derivative(f, x::AbstractArray) = throw(DimensionMismatch("derivative(f, x) expe # non-mutating # #--------------# -@inline extract_derivative(::Type{T}, y::Dual) where {T} = partials(T, y, 1) @inline extract_derivative(::Type{T}, y::Real) where {T} = zero(y) +@inline extract_derivative(::Type{T}, y::Complex) where {T} = zero(y) +@inline extract_derivative(::Type{T}, y::Dual) where {T} = partials(T, y, 1) @inline extract_derivative(::Type{T}, y::AbstractArray) where {T} = map(d -> extract_derivative(T,d), y) +@inline function extract_derivative(::Type{T}, y::Complex{TD}) where {T, TD <: Dual} + complex(partials(T, real(y), 1), partials(T, imag(y), 1)) +end # mutating # #----------# diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index f5645a7e..dfdd8ed2 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -100,4 +100,8 @@ end @test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3)) end +@testset "complex output" begin + @test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im) +end + end # module