diff --git a/Project.toml b/Project.toml index 8dc2bcf3..1c9d5d68 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.18" [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" @@ -14,11 +15,13 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +AbstractFFTs = "0.5, 1" Calculus = "0.2, 0.3, 0.4, 0.5" CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" DiffRules = "0.0.4, 0.0.5, 0.0.6, 0.0.7, 0.0.8, 0.0.9, 0.0.10, 0.1, 1.0" DiffTests = "0.0.1, 0.1" +FFTW = "1.2.4" NaNMath = "0.2.2, 0.3" SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0" StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0" @@ -27,9 +30,10 @@ julia = "1" [extras] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "FFTW", "SparseArrays", "Test", "InteractiveUtils"] diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index 51ce95a1..ca5e2b8c 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -21,6 +21,9 @@ include("gradient.jl") include("jacobian.jl") include("hessian.jl") +import AbstractFFTs +include("fft.jl") + export DiffResults end # module diff --git a/src/fft.jl b/src/fft.jl new file mode 100644 index 00000000..09ba9d46 --- /dev/null +++ b/src/fft.jl @@ -0,0 +1,70 @@ + +ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = + Complex(x.re.value, x.im.value) + +ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) = + Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n)) + +ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N +ForwardDiff.npartials(::Type{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N} = N + +# AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im) +AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.complexfloat.(x) +AbstractFFTs.complexfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d) + 0im + +AbstractFFTs.realfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.realfloat.(x) +AbstractFFTs.realfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d) + +for plan in [:plan_fft, :plan_ifft, :plan_bfft] + @eval begin + + AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region) + + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x), region) + + end +end + +# rfft only accepts real arrays +AbstractFFTs.plan_rfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = + AbstractFFTs.plan_rfft(ForwardDiff.value.(x), region) + +for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex? + @eval begin + + AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region) + + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, d::Integer, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x), d, region) + + end +end + +for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities + @eval begin + + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:ForwardDiff.Dual}) = + _apply_plan(p, x) + + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}) = + _apply_plan(p, x) + + end +end + +function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray) + xtil = p * ForwardDiff.value.(x) + dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n + p * ForwardDiff.partials.(x, n) + end + map(xtil, dxtils...) do val, parts... + Complex( + ForwardDiff.Dual(real(val), map(real, parts)), + ForwardDiff.Dual(imag(val), map(imag, parts)), + ) + end +end + diff --git a/test/FFTTest.jl b/test/FFTTest.jl new file mode 100644 index 00000000..aaeb1d8c --- /dev/null +++ b/test/FFTTest.jl @@ -0,0 +1,26 @@ +module FFTTest + +using Test +using ForwardDiff: Dual, valtype, value, partials +using FFTW +using AbstractFFTs: complexfloat, realfloat + + +x1 = Dual.(1:4.0, 2:5, 3:6) + +@test value.(x1) == 1:4 +@test partials.(x1, 1) == 2:5 + +@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im +@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + +@test fft(x1, 1)[1] isa Complex{<:Dual} + +@testset "$f" for f in [fft, ifft, rfft, bfft] + @test value.(f(x1)) == f(value.(x1)) + @test partials.(f(x1), 1) == f(partials.(x1, 1)) +end + + + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index 0b9b1d8b..87bebc04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,9 @@ using ForwardDiff +println("Testing FFT...") +t = @elapsed include("FFTTest.jl") +println("done (took $t seconds).") + println("Testing Partials...") t = @elapsed include("PartialsTest.jl") println("done (took $t seconds).")