From 850cbec72624acf0a710aa2b03b7975cbf559b30 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 22 Jun 2022 19:28:33 -0400 Subject: [PATCH] optimized 1d regression --- src/regression.jl | 60 +++++++++++++++++++++++++++++++++++++---------- test/runtests.jl | 9 +++++++ 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/src/regression.jl b/src/regression.jl index 7db4371..9eaa839 100644 --- a/src/regression.jl +++ b/src/regression.jl @@ -1,18 +1,8 @@ # Chebyshev regression: least-square fits of data # to multidimensional Chebyshev polynomials. -function chebregression(x::AbstractVector{SVector{N,Td}}, y::AbstractVector{T}, - lb::SVector{N,Td}, ub::SVector{N,Td}, order::NTuple{N,Int}) where {N,Td<:Real,T<:Union{SVector,Number}} - length(x) == length(y) || throw(DimensionMismatch()) - length(x) ≥ prod(order .+ 1) || throw(ArgumentError("not enough data points $(length(x)) to fit to order $order")) - - # assemble rhs as matrix - Y = Array{float(eltype(T))}(undef, length(y), length(first(y))) - for j = 1:length(y) - Y[j,:] .= y[j + (firstindex(y)-1)] - end - - # assemble lhs matrix +# assemble Chebyshev-Vandermonde matrix +function _chebvandermonde(x::AbstractVector{SVector{N,Td}}, lb::SVector{N,Td}, ub::SVector{N,Td}, order::NTuple{N,Int}) where {N,Td<:Real} # (TODO: this algorithm is O(length(x) * length(c)²), # but it should be possible to do it in linear time. # However, the A \ Y step is also O(mn²), so this @@ -26,6 +16,52 @@ function chebregression(x::AbstractVector{SVector{N,Td}}, y::AbstractVector{T}, end c.coefs[i] = 0 # reset end + return A +end + +# wrapper around _chebvandermonde for testing convenience +chebvandermonde(x::AbstractVector{SVector{N,Td}}, lb::SVector{N,Td}, ub::SVector{N,Td}, order::NTuple{N,Int}) where {N,Td<:Real} = + _chebvandermonde(x, lb, ub, order) + +# optimized method for 1d case +function chebvandermonde(x::AbstractVector{SVector{1,Td}}, lb::SVector{1,Td}, ub::SVector{1,Td}, order::NTuple{1,Int}) where {Td<:Real} + lb1, ub1, o1 = lb[1], ub[1], order[1] + o1 >= 0 || throw(ArgumentError("order $o1 must be nonnegative")) + A = Array{Td}(undef, length(x), o1+1) + for j = 1:length(x) + xⱼ = (x[j][1] - lb1) * 2 / (ub1 - lb1) - 1 + -1 ≤ xⱼ ≤ 1 || throw(ArgumentError("$(x[j][1]) not in domain [$lb1,$ub1]")) + A[j,1] = Tᵢ₋₂ = 1 + if o1 > 0 + Tᵢ₋₁ = xⱼ + A[j,2] = Tᵢ₋₁ + twoxⱼ = 2xⱼ + for i = 3:o1+1 # Chebyshev recurrence + A[j,i] = Tᵢ = twoxⱼ * Tᵢ₋₁ - Tᵢ₋₂ + Tᵢ₋₂, Tᵢ₋₁ = Tᵢ₋₁, Tᵢ + end + end + end + return A +end + +# convenient API for 1d case +chebvandermonde(x::AbstractVector{Td}, lb::Real, ub::Real, order::Integer) where {Td<:Real} = + return chebvandermonde(reinterpret(SVector{1,Td}, x), SVector{1,Td}(lb), SVector{1,Td}(ub), (Int(order),)) + +function chebregression(x::AbstractVector{SVector{N,Td}}, y::AbstractVector{T}, + lb::SVector{N,Td}, ub::SVector{N,Td}, order::NTuple{N,Int}) where {N,Td<:Real,T<:Union{SVector,Number}} + length(x) == length(y) || throw(DimensionMismatch()) + length(x) ≥ prod(order .+ 1) || throw(ArgumentError("not enough data points $(length(x)) to fit to order $order")) + + # assemble rhs as matrix + Y = Array{float(eltype(T))}(undef, length(y), length(first(y))) + for j = 1:length(y) + Y[j,:] .= y[j + (firstindex(y)-1)] + end + + # assemble lhs matrix + A = chebvandermonde(x, lb, ub, order) # least-square solution C = A \ Y diff --git a/test/runtests.jl b/test/runtests.jl index 26a194a..9793e64 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,6 +58,15 @@ end c = chebregression(x, y, 2) c2 = x.^(0:2)' \ y # Vandermonde-style fit @test c(1.1) ≈ c2[1] + 1.1 * c2[2] + 1.1^2 * c2[3] rtol=1e-13 + + # test specialized Vandermonde matrix constructor + x = [0.14, 0.95, 0.83, 0.13, 0.42, 0.12] + xv = reinterpret(SVector{1,Float64}, x) + A1 = FastChebInterp.chebvandermonde(x, 0.1, 0.99, 4) + A = FastChebInterp._chebvandermonde(xv, SVector(0.1), SVector(0.99), (4,)) + @test A ≈ A1 rtol=1e-13 + @test_throws ArgumentError FastChebInterp.chebvandermonde(x, 0.1, 0.9, 4) + @test_throws ArgumentError FastChebInterp._chebvandermonde(xv, SVector(0.1), SVector(0.9), (4,)) end @testset "2d regression" begin