diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f83be7ce..620d6aee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: version: - - '1.0' + - '1.6' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index ea59c5fa..3eea1da7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,8 +1,9 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.33" +version = "0.11" [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" @@ -16,24 +17,27 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +AbstractFFTs = "0.5, 1" Calculus = "0.2, 0.3, 0.4, 0.5" CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" DiffRules = "1.4.0" DiffTests = "0.0.1, 0.1" +FFTW = "1" LogExpFunctions = "0.3" NaNMath = "0.2.2, 0.3, 1" Preferences = "1" SpecialFunctions = "0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0" -julia = "1" +julia = "1.6" [extras] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "FFTW", "SparseArrays", "Test", "InteractiveUtils"] diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index 93d3b246..4b922d23 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -25,6 +25,9 @@ include("gradient.jl") include("jacobian.jl") include("hessian.jl") +import AbstractFFTs +include("fft.jl") + export DiffResults end # module diff --git a/src/dual.jl b/src/dual.jl index 7e86e9b2..40c4e688 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -128,6 +128,8 @@ end @inline tagtype(::Type{V}) where {V} = Nothing @inline tagtype(::Dual{T,V,N}) where {T,V,N} = T @inline tagtype(::Type{Dual{T,V,N}}) where {T,V,N} = T +@inline tagtype(::Complex{T}) where T = tagtype(T) +@inline tagtype(::Type{Complex{T}}) where T = tagtype(T) #################################### # N-ary Operation Definition Tools # diff --git a/src/fft.jl b/src/fft.jl new file mode 100644 index 00000000..8df9bbb9 --- /dev/null +++ b/src/fft.jl @@ -0,0 +1,84 @@ + +value(x::Complex{<:Dual}) = + Complex(x.re.value, x.im.value) + +partials(x::Complex{<:Dual}, n::Int) = + Complex(partials(x.re, n), partials(x.im, n)) + +npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N +npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N + +# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im) +AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) +AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im + +AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x) +AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + +for plan in [:plan_fft, :plan_ifft, :plan_bfft] + @eval begin + + AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), region) + + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), region) + + end +end + +# rfft only accepts real arrays +AbstractFFTs.plan_rfft(x::AbstractArray{<:Dual}, region=1:ndims(x)) = + AbstractFFTs.plan_rfft(value.(x), region) + +for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex? + @eval begin + + AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), region) + + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), d, region) + + end +end + +# for f in (:dct, :idct) +# pf = Symbol("plan_", f) +# @eval begin +# AbstractFFTs.$f(x::AbstractArray{<:Dual}) = $pf(x) * x +# AbstractFFTs.$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x +# AbstractFFTs.$pf(x::AbstractArray{<:Dual}, region; kws...) = $pf(value.(x), region; kws...) +# AbstractFFTs.$pf(x::AbstractArray{<:Complex}, region; kws...) = $pf(value.(x), region; kws...) +# end +# end + + +for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities + @eval begin + + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Dual}) = + _apply_plan(p, x) + + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:Dual}}) = + _apply_plan(p, x) + + end +end + +function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray) + xtil = p * value.(x) + dxtils = ntuple(npartials(eltype(x))) do n + p * partials.(x, n) + end + __apply_plan(tagtype(eltype(x)), xtil, dxtils) +end + +function __apply_plan(T, xtil, dxtils) + map(xtil, dxtils...) do val, parts... + Complex( + Dual{T}(real(val), map(real, parts)), + Dual{T}(imag(val), map(imag, parts)), + ) + end +end \ No newline at end of file diff --git a/test/DualTest.jl b/test/DualTest.jl index 0ec2beec..4dabdc9a 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -15,6 +15,7 @@ struct TestTag end struct OuterTestTag end samerng() = MersenneTwister(1) +Random.seed!(132) # By lower-bounding the Int range at 2, we avoid cases where differentiating an # exponentiation of an Int value would cause a DomainError due to reducing the diff --git a/test/FFTTest.jl b/test/FFTTest.jl new file mode 100644 index 00000000..f3195eb1 --- /dev/null +++ b/test/FFTTest.jl @@ -0,0 +1,32 @@ +module FFTTest + +using Test +using ForwardDiff: Dual, valtype, value, partials, derivative +using FFTW +using AbstractFFTs: complexfloat, realfloat + + +x1 = Dual.(1:4.0, 2:5, 3:6) + +@test value.(x1) == 1:4 +@test partials.(x1, 1) == 2:5 + +@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im +@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + +@test fft(x1, 1)[1] isa Complex{<:Dual} + +@testset "$f" for f in [fft, ifft, rfft, bfft] + @test value.(f(x1)) == f(value.(x1)) + @test partials.(f(x1), 1) == f(partials.(x1, 1)) +end + +f = x -> real(fft([x; 0; 0])[1]) +@test derivative(f,0.1) ≈ 1 + +r = x -> real(rfft([x; 0; 0])[1]) +@test derivative(r,0.1) ≈ 1 + +# c = x -> dct([x; 0; 0])[1] +# @test derivative(c,0.1) ≈ 1 +end # module diff --git a/test/runtests.jl b/test/runtests.jl index a1ac67b3..6f62ce8f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,4 +54,5 @@ Random.seed!(SEED) end end println("##### Running all ForwardDiff tests took $(time() - t0) seconds.") -end \ No newline at end of file +end +