From 2c13f22208a9474149e48562c5c056b08567bbca Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Thu, 14 Jan 2021 23:29:00 +0100 Subject: [PATCH 1/7] FFT support, take 1 --- Project.toml | 6 +++- src/ForwardDiff.jl | 3 ++ src/fft.jl | 70 ++++++++++++++++++++++++++++++++++++++++++++++ test/FFTTest.jl | 26 +++++++++++++++++ test/runtests.jl | 4 +++ 5 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 src/fft.jl create mode 100644 test/FFTTest.jl diff --git a/Project.toml b/Project.toml index 9c380865..26f00b65 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.14" [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" @@ -12,11 +13,13 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +AbstractFFTs = "0.5, 1" Calculus = "0.2, 0.3, 0.4, 0.5" CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" DiffRules = "0.0.4, 0.0.5, 0.0.6, 0.0.7, 0.0.8, 0.0.9, 0.0.10, 0.1, 1.0" DiffTests = "0.0.1, 0.1" +FFTW = "1.2.4" NaNMath = "0.2.2, 0.3" SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0" StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0" @@ -25,10 +28,11 @@ julia = "1" [extras] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "LinearAlgebra", "SparseArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "FFTW", "LinearAlgebra", "SparseArrays", "Test", "InteractiveUtils"] diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index cfafe6a5..5d233ea8 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -19,6 +19,9 @@ include("gradient.jl") include("jacobian.jl") include("hessian.jl") +import AbstractFFTs +include("fft.jl") + export DiffResults end # module diff --git a/src/fft.jl b/src/fft.jl new file mode 100644 index 00000000..09ba9d46 --- /dev/null +++ b/src/fft.jl @@ -0,0 +1,70 @@ + +ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = + Complex(x.re.value, x.im.value) + +ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) = + Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n)) + +ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N +ForwardDiff.npartials(::Type{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N} = N + +# AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im) +AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.complexfloat.(x) +AbstractFFTs.complexfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d) + 0im + +AbstractFFTs.realfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.realfloat.(x) +AbstractFFTs.realfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d) + +for plan in [:plan_fft, :plan_ifft, :plan_bfft] + @eval begin + + AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region) + + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x), region) + + end +end + +# rfft only accepts real arrays +AbstractFFTs.plan_rfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = + AbstractFFTs.plan_rfft(ForwardDiff.value.(x), region) + +for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex? + @eval begin + + AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region) + + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, d::Integer, region=1:ndims(x)) = + AbstractFFTs.$plan(ForwardDiff.value.(x), d, region) + + end +end + +for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities + @eval begin + + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:ForwardDiff.Dual}) = + _apply_plan(p, x) + + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}) = + _apply_plan(p, x) + + end +end + +function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray) + xtil = p * ForwardDiff.value.(x) + dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n + p * ForwardDiff.partials.(x, n) + end + map(xtil, dxtils...) do val, parts... + Complex( + ForwardDiff.Dual(real(val), map(real, parts)), + ForwardDiff.Dual(imag(val), map(imag, parts)), + ) + end +end + diff --git a/test/FFTTest.jl b/test/FFTTest.jl new file mode 100644 index 00000000..22ac2ce6 --- /dev/null +++ b/test/FFTTest.jl @@ -0,0 +1,26 @@ +module FFTTest + +using Test +using ForwardDiff: Dual, valtype, value, partials +using FFTW +using AbstractFFTs: complexfloat, realfloat + + +x1 = Dual.(1:4.0, 2:5, 3:6) + +@test value.(x1) == 1:4 +@test partials.(x1, 1) == 2:5 + +@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im +@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + +@test fft(x1, 1)[1] isa Complex{<:Dual} + +@testset "$f" for f in [fft, ifft, rfft, bfft] + @test value.(fft(x1)) == fft(value.(x1)) + @test partials.(fft(x1), 1) == fft(partials.(x1, 1)) +end + + + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index 0b9b1d8b..87bebc04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,9 @@ using ForwardDiff +println("Testing FFT...") +t = @elapsed include("FFTTest.jl") +println("done (took $t seconds).") + println("Testing Partials...") t = @elapsed include("PartialsTest.jl") println("done (took $t seconds).") From 90839a9ef4e3a5e66ee60b7248873bda20d15629 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Jun 2021 05:46:25 -0400 Subject: [PATCH 2/7] Update test/FFTTest.jl Co-authored-by: Niklas Schmitz --- test/FFTTest.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/FFTTest.jl b/test/FFTTest.jl index 22ac2ce6..aaeb1d8c 100644 --- a/test/FFTTest.jl +++ b/test/FFTTest.jl @@ -17,8 +17,8 @@ x1 = Dual.(1:4.0, 2:5, 3:6) @test fft(x1, 1)[1] isa Complex{<:Dual} @testset "$f" for f in [fft, ifft, rfft, bfft] - @test value.(fft(x1)) == fft(value.(x1)) - @test partials.(fft(x1), 1) == fft(partials.(x1, 1)) + @test value.(f(x1)) == f(value.(x1)) + @test partials.(f(x1), 1) == f(partials.(x1, 1)) end From 9be122c4f366ae92eee6fedade491f1f60af7d55 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Jun 2021 09:21:19 -0400 Subject: [PATCH 3/7] bad merge? --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3e7de082..1c9d5d68 100644 --- a/Project.toml +++ b/Project.toml @@ -36,4 +36,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "FFTW", "LinearAlgebra", "SparseArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "FFTW", "SparseArrays", "Test", "InteractiveUtils"] From 5c991eb68a31587acab27d244612ec4482f143c0 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 3 Aug 2021 11:32:54 +0100 Subject: [PATCH 4/7] Fix tagging in _apply_plan --- src/dual.jl | 2 ++ src/fft.jl | 72 +++++++++++++++++++++++++++++-------------------- test/FFTTest.jl | 8 +++++- 3 files changed, 52 insertions(+), 30 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index 947c50b5..25d97a49 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -123,6 +123,8 @@ end @inline tagtype(::Type{V}) where {V} = Nothing @inline tagtype(::Dual{T,V,N}) where {T,V,N} = T @inline tagtype(::Type{Dual{T,V,N}}) where {T,V,N} = T +@inline tagtype(::Complex{T}) where T = tagtype(T) +@inline tagtype(::Type{Complex{T}}) where T = tagtype(T) #################################### # N-ary Operation Definition Tools # diff --git a/src/fft.jl b/src/fft.jl index 09ba9d46..8df9bbb9 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1,70 +1,84 @@ -ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = +value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value) -ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) = - Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n)) +partials(x::Complex{<:Dual}, n::Int) = + Complex(partials(x.re, n), partials(x.im, n)) -ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N -ForwardDiff.npartials(::Type{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N} = N +npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N +npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N -# AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im) -AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.complexfloat.(x) -AbstractFFTs.complexfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d) + 0im +# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im) +AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) +AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im -AbstractFFTs.realfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.realfloat.(x) -AbstractFFTs.realfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d) +AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x) +AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) for plan in [:plan_fft, :plan_ifft, :plan_bfft] @eval begin - AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = - AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region) + AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), region) - AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x)) = - AbstractFFTs.$plan(ForwardDiff.value.(x), region) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), region) end end # rfft only accepts real arrays -AbstractFFTs.plan_rfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = - AbstractFFTs.plan_rfft(ForwardDiff.value.(x), region) +AbstractFFTs.plan_rfft(x::AbstractArray{<:Dual}, region=1:ndims(x)) = + AbstractFFTs.plan_rfft(value.(x), region) for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex? @eval begin - AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x)) = - AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region) + AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), region) - AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, d::Integer, region=1:ndims(x)) = - AbstractFFTs.$plan(ForwardDiff.value.(x), d, region) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, region=1:ndims(x)) = + AbstractFFTs.$plan(value.(x), d, region) end end +# for f in (:dct, :idct) +# pf = Symbol("plan_", f) +# @eval begin +# AbstractFFTs.$f(x::AbstractArray{<:Dual}) = $pf(x) * x +# AbstractFFTs.$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x +# AbstractFFTs.$pf(x::AbstractArray{<:Dual}, region; kws...) = $pf(value.(x), region; kws...) +# AbstractFFTs.$pf(x::AbstractArray{<:Complex}, region; kws...) = $pf(value.(x), region; kws...) +# end +# end + + for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities @eval begin - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:ForwardDiff.Dual}) = + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Dual}) = _apply_plan(p, x) - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}) = + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:Dual}}) = _apply_plan(p, x) end end function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray) - xtil = p * ForwardDiff.value.(x) - dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n - p * ForwardDiff.partials.(x, n) + xtil = p * value.(x) + dxtils = ntuple(npartials(eltype(x))) do n + p * partials.(x, n) end + __apply_plan(tagtype(eltype(x)), xtil, dxtils) +end + +function __apply_plan(T, xtil, dxtils) map(xtil, dxtils...) do val, parts... Complex( - ForwardDiff.Dual(real(val), map(real, parts)), - ForwardDiff.Dual(imag(val), map(imag, parts)), + Dual{T}(real(val), map(real, parts)), + Dual{T}(imag(val), map(imag, parts)), ) end -end - +end \ No newline at end of file diff --git a/test/FFTTest.jl b/test/FFTTest.jl index aaeb1d8c..f3195eb1 100644 --- a/test/FFTTest.jl +++ b/test/FFTTest.jl @@ -1,7 +1,7 @@ module FFTTest using Test -using ForwardDiff: Dual, valtype, value, partials +using ForwardDiff: Dual, valtype, value, partials, derivative using FFTW using AbstractFFTs: complexfloat, realfloat @@ -21,6 +21,12 @@ x1 = Dual.(1:4.0, 2:5, 3:6) @test partials.(f(x1), 1) == f(partials.(x1, 1)) end +f = x -> real(fft([x; 0; 0])[1]) +@test derivative(f,0.1) ≈ 1 +r = x -> real(rfft([x; 0; 0])[1]) +@test derivative(r,0.1) ≈ 1 +# c = x -> dct([x; 0; 0])[1] +# @test derivative(c,0.1) ≈ 1 end # module From ebabe9e9af58fec0d7898e35e1175f0477476d35 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 3 Aug 2021 11:57:07 +0100 Subject: [PATCH 5/7] Weaken FFTW requirement --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ae1d922b..3a19ecfd 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" DiffRules = "1.2.1" DiffTests = "0.0.1, 0.1" -FFTW = "1.2.4" +FFTW = "1" NaNMath = "0.2.2, 0.3" SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0" StaticArrays = "0.8.3, 0.9, 0.10, 0.11, 0.12, 1.0" From c9c155ad5cfd927b0b425bbb6f7c2a9fe1d5bb5c Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 31 Aug 2021 20:43:57 +0100 Subject: [PATCH 6/7] seed! random --- test/DualTest.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/DualTest.jl b/test/DualTest.jl index 878e9bd2..8091edca 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -14,6 +14,7 @@ import Calculus struct TestTag end samerng() = MersenneTwister(1) +Random.seed!(132) # By lower-bounding the Int range at 2, we avoid cases where differentiating an # exponentiation of an Int value would cause a DomainError due to reducing the From 98e1d2e0899508e5955b73bab028118143474f47 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Sun, 20 Nov 2022 19:32:23 +0000 Subject: [PATCH 7/7] Drop