Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/FastChebInterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@ struct ChebPoly{N,T,Td<:Real} <: Function
ub::SVector{N,Td} # of the domain
end

function Base.show(io::IO, c::ChebPoly)
print(io, "Chebyshev order ", map(i->i-1,size(c.coefs)), " polynomial on ",
function Base.show(io::IO, c::ChebPoly{N,T,Td}) where {N,T,Td}
print(io, "ChebPoly{$N,$T,$Td} order ", map(i->i-1,size(c.coefs)), " polynomial on ",
'[', c.lb[1], ',', c.ub[1], ']')
for i = 2:length(c.lb)
print(io, " × [", c.lb[i], ',', c.ub[i], ']')
end
end

# need explicit 3-arg show so that we don't call the
# 3-arg ::Function method:
Base.show(io::IO, ::MIME"text/plain", c::ChebPoly) = show(io, c)

Base.ndims(c::ChebPoly) = ndims(c.coefs)
Base.zero(c::ChebPoly{N,T,Td}) where {N,T,Td} = ChebPoly{N,T,Td}(zero(c.coefs), c.lb, c.ub)

Expand Down
4 changes: 2 additions & 2 deletions src/eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Evaluate the Chebyshev polynomial given by `interp` at the point `x`.
"""
@fastmath function (interp::ChebPoly{N})(x::SVector{N,<:Real}) where {N}
x0 = @. (x - interp.lb) * 2 / (interp.ub - interp.lb) - 1
all(abs.(x0) .≤ 1) || throw(ArgumentError("$x not in domain"))
all(abs.(x0) .≤ 1) || throw(ArgumentError("$x not in domain $(interp.lb) to $(interp.ub)"))
return evaluate(x0, interp.coefs, Val{N}(), 1, length(interp.coefs))
end

Expand Down Expand Up @@ -146,7 +146,7 @@ is a 1-row matrix; in this case you may wish to call `chebgradient` instead.
"""
function chebjacobian(c::ChebPoly{N}, x::SVector{N,<:Real}) where {N}
x0 = @. (x - c.lb) * 2 / (c.ub - c.lb) - 1
all(abs.(x0) .≤ 1) || throw(ArgumentError("$x not in domain"))
all(abs.(x0) .≤ 1) || throw(ArgumentError("$x not in domain $(c.lb) to $(c.ub)"))
v, J = Jevaluate(x0, c.coefs, Val{N}(), 1, length(c.coefs))
return v, J .* 2 ./ (c.ub .- c.lb)'
end
Expand Down
6 changes: 4 additions & 2 deletions src/interp.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# "fitting" (actually just interpolating) Chebyshev polynomials
# to functions evaluated at Chebyshev points.

chebpoint(i::CartesianIndex{N}, order::NTuple{N,Int}, lb::SVector{N}, ub::SVector{N}) where {N} =
@. lb + (1 + cos($SVector($Tuple(i)) * π / $SVector(ifelse.(iszero.(order),2,order)))) * (ub - lb) * 0.5
function chebpoint(i::CartesianIndex{N}, order::NTuple{N,Int}, lb::SVector{N}, ub::SVector{N}) where {N}
T = typeof(float(one(eltype(lb)) * one(eltype(ub))))
@. lb + (1 + cos(T($SVector($Tuple(i))) * π / $SVector(ifelse.(iszero.(order),2,order)))) * (ub - lb) * $(T(0.5))
end

chebpoints(order::NTuple{N,Int}, lb::SVector{N}, ub::SVector{N}) where {N} =
[chebpoint(i,order,lb,ub) for i in CartesianIndices(map(n -> n==0 ? (1:1) : (0:n), order))]
Expand Down
110 changes: 60 additions & 50 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,63 +1,73 @@
using Test, FastChebInterp, StaticArrays, Random, ChainRulesTestUtils

# similar to ≈, but acts elementwise on tuples
≈′(a::Tuple, b::Tuple; kws...) where {N} = length(a) == length(b) && all(xy -> isapprox(xy[1],xy[2]; kws...), zip(a,b))
≈′(a::Tuple, b::Tuple; kws...) = length(a) == length(b) && all(xy -> isapprox(xy[1],xy[2]; kws...), zip(a,b))

Random.seed!(314159) # make chainrules tests deterministic

@testset "1d test" begin
lb,ub = -0.3, 0.9
f(x) = exp(x) / (1 + 2x^2)
f′(x) = f(x) * (1 - 4x/(1 + 2x^2))
@test_throws ArgumentError chebpoints(-1, lb, ub)
x = chebpoints(48, lb, ub)
interp = chebinterp(f.(x), lb, ub)
@test ndims(interp) == 1
x1 = 0.2
@test interp(x1) ≈ f(x1)
@test chebgradient(interp, x1) ≈′ (f(x1), f′(x1))
test_frule(interp, x1)
test_rrule(interp, x1)
for T in (Float32, Float64)
lb,ub = T(-0.3), T(0.9)
f(x) = exp(x) / (1 + 2x^2)
f′(x) = f(x) * (1 - 4x/(1 + 2x^2))
@test_throws ArgumentError chebpoints(-1, lb, ub)
x = chebpoints(48, lb, ub)
@test eltype(x) == T
interp = chebinterp(f.(x), lb, ub, tol=0)
@test interp isa FastChebInterp.ChebPoly{1,T,T}
@test repr("text/plain", interp) == "ChebPoly{1,$T,$T} order (48,) polynomial on [-0.3,0.9]"
@test ndims(interp) == 1
x1 = T(0.2)
@test interp(x1) ≈ f(x1)
@test chebgradient(interp, x1) ≈′ (f(x1), f′(x1))
test_frule(interp, x1, rtol=sqrt(eps(T)), atol=sqrt(eps(T)))
test_rrule(interp, x1, rtol=sqrt(eps(T)), atol=sqrt(eps(T)))
end
end

@testset "2d test" begin
lb, ub = [-0.3,0.1], [0.9,1.2]
f(x) = exp(x[1]+2*x[2]) / (1 + 2x[1]^2 + x[2]^2)
∇f(x) = f(x) * SVector(1 - 4x[1]/(1 + 2x[1]^2 + x[2]^2), 2 - 2x[2]/(1 + 2x[1]^2 + x[2]^2))
x = chebpoints((48,39), lb, ub)
interp = chebinterp(f.(x), lb, ub)
interp0 = chebinterp(f.(x), lb, ub, tol=0)
@test ndims(interp) == 2
x1 = [0.2, 0.3]
@test interp(x1) ≈ f(x1)
@test interp(x1) ≈ interp0(x1) rtol=1e-15
@test all(n -> n[1] < n[2], zip(size(interp.coefs), size(interp0.coefs)))
@test chebgradient(interp, x1) ≈′ (f(x1), ∇f(x1))
test_frule(interp, x1)
test_rrule(interp, x1)

# univariate function in 2d should automatically drop down to univariate polynomial
f1(x) = exp(x[1]) / (1 + 2x[1]^2)
interp1 = chebinterp(f1.(x), lb, ub)
@test interp1(x1) ≈ f1(x1)
@test size(interp1.coefs, 2) == 1 # second dimension should have been dropped

# complex and vector-valued interpolants:
f2(x) = [f(x), cis(x[1]*x[2] + 2x[2])]
∇f2(x) = vcat(transpose(∇f(x)), transpose(SVector(im*x[2], im*(x[1] + 2)) * cis(x[1]*x[2] + 2x[2])))
interp2 = chebinterp(f2.(x), lb, ub)
@test interp2(x1) ≈ f2(x1)
@test chebjacobian(interp2, x1) ≈′ (f2(x1), ∇f2(x1))
test_frule(interp2, x1)
test_rrule(interp2, x1)

# chebinterp_v1
av1 = Array{ComplexF64}(undef, 2, size(x)...)
av1[1,:,:] .= f.(x)
av1[2,:,:] .= (x -> f2(x)[2]).(x)
interp2v1 = chebinterp_v1(av1, lb, ub)
@test interp2v1(x1) ≈ f2(x1)
@test chebjacobian(interp2v1, x1) ≈′ (f2(x1), ∇f2(x1))
for T in (Float32, Float64)
lb, ub = T[-0.3,0.1], T[0.9,1.2]
f(x) = exp(x[1]+2*x[2]) / (1 + 2x[1]^2 + x[2]^2)
∇f(x) = f(x) * SVector(1 - 4x[1]/(1 + 2x[1]^2 + x[2]^2), 2 - 2x[2]/(1 + 2x[1]^2 + x[2]^2))
x = chebpoints((48,39), lb, ub)
@test eltype(x) == SVector{2,T}
interp = chebinterp(f.(x), lb, ub)
@test interp isa FastChebInterp.ChebPoly{2,T,T}
interp0 = chebinterp(f.(x), lb, ub, tol=0)
@test repr("text/plain", interp0) == "ChebPoly{2,$T,$T} order (48, 39) polynomial on [-0.3,0.9] × [0.1,1.2]"
@test ndims(interp) == 2
x1 = T[0.2, 0.3]
@test interp(x1) ≈ f(x1)
@test interp(x1) ≈ interp0(x1) rtol=10eps(T)
@test all(n -> n[1] < n[2], zip(size(interp.coefs), size(interp0.coefs)))
@test chebgradient(interp, x1) ≈′ (f(x1), ∇f(x1))
test_frule(interp, x1, rtol=sqrt(eps(T)), atol=sqrt(eps(T)))
test_rrule(interp, x1, rtol=sqrt(eps(T)), atol=sqrt(eps(T)))

# univariate function in 2d should automatically drop down to univariate polynomial
f1(x) = exp(x[1]) / (1 + 2x[1]^2)
interp1 = chebinterp(f1.(x), lb, ub)
@test interp1(x1) ≈ f1(x1)
@test size(interp1.coefs, 2) == 1 # second dimension should have been dropped

# complex and vector-valued interpolants:
f2(x) = [f(x), cis(x[1]*x[2] + 2x[2])]
∇f2(x) = vcat(transpose(∇f(x)), transpose(SVector(im*x[2], im*(x[1] + 2)) * cis(x[1]*x[2] + 2x[2])))
interp2 = chebinterp(f2.(x), lb, ub)
@test interp2(x1) ≈ f2(x1)
@test chebjacobian(interp2, x1) ≈′ (f2(x1), ∇f2(x1))
test_frule(interp2, x1, rtol=sqrt(eps(T)), atol=sqrt(eps(T)))
test_rrule(interp2, x1, rtol=sqrt(eps(T)), atol=sqrt(eps(T)))

# chebinterp_v1
av1 = Array{Complex{T}}(undef, 2, size(x)...)
av1[1,:,:] .= f.(x)
av1[2,:,:] .= (x -> f2(x)[2]).(x)
interp2v1 = chebinterp_v1(av1, lb, ub)
@test interp2v1(x1) ≈ f2(x1)
@test chebjacobian(interp2v1, x1) ≈′ (f2(x1), ∇f2(x1))
end
end

@testset "1d regression" begin
Expand Down