From 1193331049c6d2f1567f753eb8646b96de75cf00 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Tue, 23 May 2023 15:50:45 -0600 Subject: [PATCH] Update chainrules.jl Update chainrules.jl relax type format --- src/chainrules.jl | 6 ++++-- src/derivative.jl | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 0c7269d1..68ebcce3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -31,10 +31,12 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, end function rrule(::typeof(*), A::AbstractMatrix{S}, - t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T} + t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Number, T} project_A = ProjectTo(A) function gemv_pullback(x̄) - NoTangent(), @thunk(project_A(contract.(x̄, transpose(t)))), @thunk(transpose(A)*x̄) + x̂ = reinterpret(reshape, T, x̄) + t̂ = reinterpret(reshape, T, t) + NoTangent(), @thunk(project_A(transpose(x̂) * t̂)), @thunk(transpose(A)*x̄) end return A * t, gemv_pullback end diff --git a/src/derivative.jl b/src/derivative.jl index 00cbcc05..81b66da7 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -18,8 +18,8 @@ function derivative end derivative(f, x, Val{order + 1}()) end -@inline function derivative(f, x::V, l::V, - order::Int64) where {V <: AbstractArray{<:Number, 1}} +@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, + order::Int64) where {T <: Number, S <: Number} derivative(f, x, l, Val{order + 1}()) end @@ -29,10 +29,10 @@ end end # Need to rewrite like this to help Zygote infer types -make_taylor(t0::T, t1::T, ::Val{N}) where {T, N} = TaylorScalar{T, N}(t0, t1) +make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) -@inline function derivative(f, x::V, l::V, - vN::Val{N}) where {V <: AbstractArray{<:Number, 1}, N} +@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, + vN::Val{N}) where {T <: Number, S <: Number, N} t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) end