From 0ff0853c42414b9d7424b7cb4d3b82e7a80bd55e Mon Sep 17 00:00:00 2001 From: zhujch Date: Tue, 7 Nov 2023 17:04:51 -0500 Subject: [PATCH 1/2] derivative for matrix --- src/chainrules.jl | 11 +++++++++++ src/derivative.jl | 3 ++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index d2647a5..4019fce 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -41,6 +41,17 @@ function rrule(::typeof(*), A::AbstractMatrix{S}, return A * t, gemv_pullback end +function rrule(::typeof(*), A::AbstractMatrix{S}, + B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T} + project_A = ProjectTo(A) + project_B = ProjectTo(B) + function gemm_pullback(x̄) + X̄ = unthunk(x̄) + NoTangent(), @thunk(project_A(X̄ * transpose(B))), @thunk(project_B(transpose(A) * X̄)) + end + return A * B, gemm_pullback +end + @adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T} project_v = ProjectTo(v) t + v, x̄ -> (x̄, project_v(x̄)) diff --git a/src/derivative.jl b/src/derivative.jl index 1cf1d6f..bc11282 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -61,5 +61,6 @@ end @inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S}, vN::Val{N}) where {T <: TN, S <: TN, N} - mapcols(u -> derivative(f, u, l, vN), x) + t = make_taylor.(x, l, vN) + return extract_derivative.(f(t), N) end From 15b98c34838d91f296f1dc004f07dc938eb2e1a4 Mon Sep 17 00:00:00 2001 From: zhujch Date: Wed, 8 Nov 2023 14:25:44 -0500 Subject: [PATCH 2/2] remove dependency SliceMap --- Project.toml | 2 -- src/derivative.jl | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 06f2d0e..cd1b476 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" IrrationalConstants = 
"92d709cd-6900-40b7-9082-c6be49f344b6" -SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -19,7 +18,6 @@ ChainRules = "1" ChainRulesCore = "1" ChainRulesOverloadGeneration = "0.1" IrrationalConstants = "0.2" -SliceMap = "0.2" SpecialFunctions = "2" SymbolicUtils = "1" Zygote = "0.6.55" diff --git a/src/derivative.jl b/src/derivative.jl index bc11282..62a60c7 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -1,5 +1,4 @@ -using SliceMap export derivative """ @@ -56,7 +55,8 @@ end @inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N} size(x)[1] != 1 && @warn "x is not a row vector." - mapcols(u -> derivative(f, u[1], vN), x) + t = make_taylor.(x, one(N), vN) + return extract_derivative.(f(t), N) end @inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},