diff --git a/src/chainrules.jl b/src/chainrules.jl index 0c7269d1..68ebcce3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -31,10 +31,12 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, end function rrule(::typeof(*), A::AbstractMatrix{S}, - t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T} + t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Number, T} project_A = ProjectTo(A) function gemv_pullback(x̄) - NoTangent(), @thunk(project_A(contract.(x̄, transpose(t)))), @thunk(transpose(A)*x̄) + x̂ = reinterpret(reshape, T, x̄) + t̂ = reinterpret(reshape, T, t) + NoTangent(), @thunk(project_A(transpose(x̂) * t̂)), @thunk(transpose(A)*x̄) end return A * t, gemv_pullback end diff --git a/src/derivative.jl b/src/derivative.jl index 00cbcc05..81b66da7 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -18,8 +18,8 @@ function derivative end derivative(f, x, Val{order + 1}()) end -@inline function derivative(f, x::V, l::V, - order::Int64) where {V <: AbstractArray{<:Number, 1}} +@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, + order::Int64) where {T <: Number, S <: Number} derivative(f, x, l, Val{order + 1}()) end @@ -29,10 +29,10 @@ end end # Need to rewrite like this to help Zygote infer types -make_taylor(t0::T, t1::T, ::Val{N}) where {T, N} = TaylorScalar{T, N}(t0, t1) +make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) -@inline function derivative(f, x::V, l::V, - vN::Val{N}) where {V <: AbstractArray{<:Number, 1}, N} +@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, + vN::Val{N}) where {T <: Number, S <: Number, N} t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) end