# Patch summary: drop the SliceMap dependency and handle matrix inputs by broadcasting.
#
# Project.toml
#   Removes the `SliceMap` UUID entry and the `SliceMap = "0.2"` version entry.
#
# src/chainrules.jl
#   Adds `rrule(::typeof(*), A::AbstractMatrix{S}, B::AbstractMatrix{TaylorScalar{T, N}})`.
#   The primal result is `A * B`. The pullback `gemm_pullback` unthunks the incoming
#   cotangent `X̄` and returns `NoTangent()` for `*`, plus two thunks:
#   `project_A(X̄ * transpose(B))` for `A` and `project_B(transpose(A) * X̄)` for `B`.
#   The new method is inserted right after the existing method whose pullback is
#   `gemv_pullback`.
#
# src/derivative.jl
#   Removes `using SliceMap`. Both `derivative(f, x::AbstractMatrix, ...)` methods stop
#   calling `mapcols`; they now build `t = make_taylor.(x, seed, vN)` and return
#   `extract_derivative.(f(t), N)`. The seed is `one(N)` in the method without `l`,
#   and `l` in the method that takes `l::AbstractVector`.
#
# NOTE(review): the removed code in the first method passed only `u[1]` of each column
#   to `derivative`; the replacement broadcasts over every entry of `x` and calls `f`
#   once on the whole matrix. The "x is not a row vector." warning is unchanged --
#   confirm the behaviour for multi-row `x` is intended.
# NOTE(review): `one(N)` is taken on the `Val` parameter `N`, not on the element type
#   `T` -- presumably `make_taylor` converts the seed; verify against its definition.
# NOTE(review): this copy of the patch has its line breaks collapsed; the hunks below are
#   kept byte-for-byte as found.
diff --git a/Project.toml b/Project.toml index 06f2d0e..cd1b476 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" -SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -19,7 +18,6 @@ ChainRules = "1" ChainRulesCore = "1" ChainRulesOverloadGeneration = "0.1" IrrationalConstants = "0.2" -SliceMap = "0.2" SpecialFunctions = "2" SymbolicUtils = "1" Zygote = "0.6.55" diff --git a/src/chainrules.jl b/src/chainrules.jl index d2647a5..4019fce 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -41,6 +41,17 @@ function rrule(::typeof(*), A::AbstractMatrix{S}, return A * t, gemv_pullback end +function rrule(::typeof(*), A::AbstractMatrix{S}, + B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T} + project_A = ProjectTo(A) + project_B = ProjectTo(B) + function gemm_pullback(x̄) + X̄ = unthunk(x̄) + NoTangent(), @thunk(project_A(X̄ * transpose(B))), @thunk(project_B(transpose(A) * X̄)) + end + return A * B, gemm_pullback +end + @adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T} project_v = ProjectTo(v) t + v, x̄ -> (x̄, project_v(x̄)) diff --git a/src/derivative.jl b/src/derivative.jl index 1cf1d6f..62a60c7 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -1,5 +1,4 @@ -using SliceMap export derivative """ @@ -56,10 +55,12 @@ end @inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N} size(x)[1] != 1 && @warn "x is not a row vector." 
- mapcols(u -> derivative(f, u[1], vN), x) + t = make_taylor.(x, one(N), vN) + return extract_derivative.(f(t), N) end @inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S}, vN::Val{N}) where {T <: TN, S <: TN, N} - mapcols(u -> derivative(f, u, l, vN), x) + t = make_taylor.(x, l, vN) + return extract_derivative.(f(t), N) end