diff --git a/src/derivative.jl b/src/derivative.jl index a60d8d0e..5d7e5dab 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -70,6 +70,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba end derivative(f, x::AbstractArray) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number. Perhaps you meant gradient(f, x)?")) +derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input.")) ##################### # result extraction # @@ -78,9 +79,13 @@ derivative(f, x::AbstractArray) = throw(DimensionMismatch("derivative(f, x) expe # non-mutating # #--------------# -@inline extract_derivative(::Type{T}, y::Dual) where {T} = partials(T, y, 1) @inline extract_derivative(::Type{T}, y::Real) where {T} = zero(y) +@inline extract_derivative(::Type{T}, y::Complex) where {T} = zero(y) +@inline extract_derivative(::Type{T}, y::Dual) where {T} = partials(T, y, 1) @inline extract_derivative(::Type{T}, y::AbstractArray) where {T} = map(d -> extract_derivative(T,d), y) +@inline function extract_derivative(::Type{T}, y::Complex{TD}) where {T, TD <: Dual} + complex(partials(T, real(y), 1), partials(T, imag(y), 1)) +end # mutating # #----------# diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index f5645a7e..dfdd8ed2 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -100,4 +100,8 @@ end @test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3)) end +@testset "complex output" begin + @test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im) +end + end # module