From b79afb63cf20a91912419bc9a8a5444b2d4e525a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 07:00:34 +0200 Subject: [PATCH 01/13] Methods with and without extras --- .../src/first_order/derivative.jl | 56 ++++++---- .../src/first_order/gradient.jl | 32 +++--- .../src/first_order/jacobian.jl | 100 +++++++----------- .../src/first_order/pullback.jl | 82 ++++++++------ .../src/first_order/pushforward.jl | 87 ++++++++------- .../src/second_order/hessian.jl | 61 ++++++----- .../src/second_order/hvp.jl | 52 ++++----- .../src/second_order/second_derivative.jl | 68 ++++++------ 8 files changed, 280 insertions(+), 258 deletions(-) diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 261d263b5..3baa32338 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -72,46 +72,66 @@ end ## One argument +function value_and_derivative(f::F, backend::AbstractADType, x) where {F} + return value_and_derivative(f, backend, x, prepare_derivative(f, backend, x)) +end + +function value_and_derivative!(f::F, der, backend::AbstractADType, x) where {F} + return value_and_derivative!(f, der, backend, x, prepare_derivative(f, backend, x)) +end + +function derivative(f::F, backend::AbstractADType, x) where {F} + return derivative(f, backend, x, prepare_derivative(f, backend, x)) +end + +function derivative!(f::F, der, backend::AbstractADType, x) where {F} + return derivative!(f, der, backend, x, prepare_derivative(f, backend, x)) +end + function value_and_derivative( - f::F, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward(f, backend, x, one(x), extras.pushforward_extras) end function value_and_derivative!( - f::F, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward!(f, der, backend, x, one(x), extras.pushforward_extras) end function derivative( - f::F, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward(f, backend, x, one(x), extras.pushforward_extras) end function derivative!( - f::F, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward!(f, der, backend, x, one(x), extras.pushforward_extras) end ## Two arguments +function value_and_derivative(f!::F, y, backend::AbstractADType, x) where {F} + return value_and_derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x)) +end + +function value_and_derivative!(f!::F, y, der, backend::AbstractADType, x) where {F} + return value_and_derivative!( + f!, y, der, backend, x, prepare_derivative(f!, y, backend, x) + ) +end + +function derivative(f!::F, y, backend::AbstractADType, x) where {F} + return derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x)) +end + +function derivative!(f!::F, y, der, backend::AbstractADType, x) where {F} + return derivative!(f!, y, der, backend, x, prepare_derivative(f!, y, backend, x)) +end + function value_and_derivative( f!::F, y, diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 35fcbd41a..aab866768 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -62,34 +62,42 @@ end ## One argument +function value_and_gradient(f::F, backend::AbstractADType, x) where {F} + return value_and_gradient(f, backend, x, prepare_gradient(f, backend, x)) +end + +function value_and_gradient!(f::F, der, backend::AbstractADType, x) where {F} + return value_and_gradient!(f, der, backend, x, prepare_gradient(f, backend, x)) +end + +function gradient(f::F, backend::AbstractADType, x) where {F} + return gradient(f, backend, x, prepare_gradient(f, backend, x)) +end + +function gradient!(f::F, der, backend::AbstractADType, x) where {F} + return gradient!(f, der, backend, x, prepare_gradient(f, backend, x)) +end + function value_and_gradient( - f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x) + f::F, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return value_and_pullback(f, backend, x, one(eltype(x)), extras.pullback_extras) end function value_and_gradient!( - f::F, - grad, - backend::AbstractADType, - x, - extras::GradientExtras=prepare_gradient(f, backend, x), + f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return value_and_pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras) end function gradient( - f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x) + f::F, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return pullback(f, backend, x, one(eltype(x)), extras.pullback_extras) end function gradient!( - f::F, - grad, - backend::AbstractADType, - x, - extras::GradientExtras=prepare_gradient(f, backend, x), + f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 9847a02fc..825071169 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -98,13 +98,23 @@ end ## One argument -function value_and_jacobian( - f::F, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x) -) where {F} - return value_and_jacobian_onearg_aux(f, backend, x, extras) +function value_and_jacobian(f!::F, y, backend::AbstractADType, x) where {F} + return value_and_jacobian(f!, y, backend, x, prepare_jacobian(f, backend, x)) +end + +function value_and_jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} + return value_and_jacobian!(f!, y, jac, backend, x, prepare_jacobian(f, backend, x)) +end + +function jacobian(f!::F, y, backend::AbstractADType, x) where {F} + return jacobian(f!, y, backend, x, prepare_jacobian(f, backend, x)) +end + +function jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} + return jacobian!(f!, y, jac, backend, x, prepare_jacobian(f, backend, x)) end -function value_and_jacobian_onearg_aux( +function value_and_jacobian( f::F, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -123,7 +133,7 @@ function value_and_jacobian_onearg_aux( return y, jac end -function value_and_jacobian_onearg_aux( +function value_and_jacobian( f::F, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -139,16 +149,6 @@ function value_and_jacobian_onearg_aux( end function value_and_jacobian!( - f::F, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f, backend, x), -) where {F} - return value_and_jacobian_onearg_aux!(f, jac, backend, x, extras) -end - -function value_and_jacobian_onearg_aux!( f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -167,7 +167,7 @@ function value_and_jacobian_onearg_aux!( return y, jac end -function value_and_jacobian_onearg_aux!( +function value_and_jacobian!( f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -182,35 +182,33 @@ function value_and_jacobian_onearg_aux!( return y, jac end -function jacobian( - f::F, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x) -) where {F} +function jacobian(f::F, backend::AbstractADType, x, extras::JacobianExtras) where {F} return value_and_jacobian(f, backend, x, extras)[2] end -function jacobian!( - f::F, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f, backend, x), -) where {F} +function jacobian!(f::F, jac, backend::AbstractADType, x, extras::JacobianExtras) where {F} return value_and_jacobian!(f, jac, backend, x, extras)[2] end ## Two arguments -function value_and_jacobian( - f!::F, - y, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), -) where {F} - return value_and_jacobian_twoarg_aux(f!, y, backend, x, extras) +function value_and_jacobian(f!::F, y, backend::AbstractADType, x) where {F} + return value_and_jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x)) +end + +function value_and_jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} + return value_and_jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x)) +end + +function jacobian(f!::F, y, backend::AbstractADType, x) where {F} + return jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x)) end -function value_and_jacobian_twoarg_aux( +function jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} + return jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x)) +end + +function value_and_jacobian( f!::F, y, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} pushforward_extras_same = prepare_pushforward_same_point( @@ -230,7 +228,7 @@ function value_and_jacobian_twoarg_aux( return y, jac end -function value_and_jacobian_twoarg_aux( +function value_and_jacobian( f!::F, y, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} pullback_extras_same = prepare_pullback_same_point( @@ -251,17 +249,6 @@ function value_and_jacobian_twoarg_aux( end function value_and_jacobian!( - f!::F, - y, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), -) where {F} - return value_and_jacobian_twoarg_aux!(f!, y, jac, backend, x, extras) -end - -function value_and_jacobian_twoarg_aux!( f!::F, y, jac::AbstractMatrix, @@ -286,7 +273,7 @@ function value_and_jacobian_twoarg_aux!( return y, jac end -function value_and_jacobian_twoarg_aux!( +function value_and_jacobian!( f!::F, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} pullback_extras_same = prepare_pullback_same_point( @@ -306,23 +293,12 @@ function value_and_jacobian_twoarg_aux!( return y, jac end -function jacobian( - f!::F, - y, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), -) where {F} +function jacobian(f!::F, y, backend::AbstractADType, x, extras::JacobianExtras) where {F} return value_and_jacobian(f!, y, backend, x, extras)[2] end function jacobian!( - f!::F, - y, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), + f!::F, y, jac, backend::AbstractADType, x, extras::JacobianExtras ) where {F} return value_and_jacobian!(f!, y, jac, backend, x, extras)[2] end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index ed7119cd9..b5f343802 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -95,8 +95,14 @@ function prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackSlow) where {F end # Throw error if backend is missing -prepare_pullback_aux(f::F, backend, x, dy, ::PullbackFast) where {F} = throw(MissingBackendError(backend)) -prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackFast) where {F} = throw(MissingBackendError(backend)) + +function prepare_pullback_aux(f, backend, x, dy, ::PullbackFast) + throw(MissingBackendError(backend)) +end + +function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast) + throw(MissingBackendError(backend)) +end ## Preparation (same point) @@ -124,6 +130,22 @@ end ## One argument +function value_and_pullback(f::F, backend::AbstractADType, x, dy) where {F} + return value_and_pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy)) +end + +function value_and_pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F} + return value_and_pullback!(f, dx, backend, dy, x, prepare_pullback(f, backend, x, dy)) +end + +function pullback(f::F, backend::AbstractADType, x, dy) where {F} + return pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy)) +end + +function pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F} + return pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy)) +end + function value_and_pullback( f::F, backend::AbstractADType, @@ -131,10 +153,10 @@ function value_and_pullback( dy, extras::PullbackExtras=prepare_pullback(f, backend, x, dy), ) where {F} - return value_and_pullback_onearg_aux(f, backend, x, dy, extras) + return value_and_pullback(f, backend, x, dy, extras) end -function value_and_pullback_onearg_aux( +function value_and_pullback( f::F, backend, x, dy, extras::PushforwardPullbackExtras ) where {F} @compat (; pushforward_extras) = extras @@ -190,18 +212,27 @@ end ## Two arguments -function value_and_pullback( - f!::F, - y, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), -) where {F} - return value_and_pullback_twoarg_aux(f!, y, backend, x, dy, extras) +function value_and_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F} + return value_and_pullback( + f!, y, backend, x, dy, prepare_pullback(f!, y, backend, x, dy) + ) end -function value_and_pullback_twoarg_aux( +function value_and_pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F} + return value_and_pullback!( + f!, y, dx, backend, dy, x, prepare_pullback(f!, y, backend, x, dy) + ) +end + +function pullback(f!::F, y, backend::AbstractADType, x, dy) where {F} + return pullback(f!, y, backend, x, dy, prepare_pullback(f!, y, backend, x, dy)) +end + +function pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F} + return pullback!(f!, y, dx, backend, x, dy, prepare_pullback(f!, y, backend, x, dy)) +end + +function value_and_pullback( f!::F, y, backend, x, dy, extras::PushforwardPullbackExtras ) where {F} @compat (; pushforward_extras) = extras @@ -217,37 +248,20 @@ function value_and_pullback_twoarg_aux( end function value_and_pullback!( - f!::F, - y, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), + f!::F, y, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} y, new_dx = value_and_pullback(f!, y, backend, x, dy, extras) return y, copyto!(dx, new_dx) end function pullback( - f!::F, - y, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), + f!::F, y, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} return value_and_pullback(f!, y, backend, x, dy, extras)[2] end function pullback!( - f!::F, - y, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), + f!::F, y, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} return value_and_pullback!(f!, y, dx, backend, x, dy, extras)[2] end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index db1a6784d..4ac318c69 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -96,8 +96,14 @@ function prepare_pushforward_aux(f!::F, y, backend, x, dx, ::PushforwardSlow) wh end # Throw error if backend is missing -prepare_pushforward_aux(f::F, backend, x, dy, ::PushforwardFast) where {F} = throw(MissingBackendError(backend)) -prepare_pushforward_aux(f!::F, y, backend, x, dy, ::PushforwardFast) where {F} = throw(MissingBackendError(backend)) + +function prepare_pushforward_aux(f, backend, x, dy, ::PushforwardFast) + throw(MissingBackendError(backend)) +end + +function prepare_pushforward_aux(f!, y, backend, x, dy, ::PushforwardFast) + throw(MissingBackendError(backend)) +end ## Preparation (same point) @@ -125,17 +131,25 @@ end ## One argument -function value_and_pushforward( - f::F, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), -) where {F} - return value_and_pushforward_onearg_aux(f, backend, x, dx, extras) +function value_and_pushforward(f::F, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward(f, backend, x, dx, prepare_pushforward(f, backend, x, dx)) +end + +function value_and_pushforward!(f::F, dy, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward!( + f, dy, backend, x, dx, prepare_pushforward(f, backend, x, dx) + ) end -function value_and_pushforward_onearg_aux( +function pushforward(f::F, backend::AbstractADType, x, dx) where {F} + return pushforward(f, backend, x, dx, prepare_pushforward(f, backend, x, dx)) +end + +function pushforward!(f::F, dy, backend::AbstractADType, x, dx) where {F} + return pushforward!(f, dy, backend, x, dx, prepare_pushforward(f, backend, x, dx)) +end + +function value_and_pushforward( f::F, backend, x, dx, extras::PullbackPushforwardExtras ) where {F} @compat (; pullback_extras) = extras @@ -157,52 +171,49 @@ function value_and_pushforward_onearg_aux( end function value_and_pushforward!( - f::F, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), + f::F, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} y, new_dy = value_and_pushforward(f, backend, x, dx, extras) return y, copyto!(dy, new_dy) end function pushforward( - f::F, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), + f::F, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward(f, backend, x, dx, extras)[2] end function pushforward!( - f::F, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), + f::F, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward!(f, dy, backend, x, dx, extras)[2] end ## Two arguments -function value_and_pushforward( - f!::F, - y, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), -) where {F} - return value_and_pushforward_twoarg_aux(f!, y, backend, x, dx, extras) +function value_and_pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward( + f!, y, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx) + ) +end + +function value_and_pushforward!(f!::F, y, dy, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward!( + f!, y, dy, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx) + ) end -function value_and_pushforward_twoarg_aux( +function pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F} + return pushforward(f!, y, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx)) +end + +function pushforward!(f!::F, y, dy, backend::AbstractADType, x, dx) where {F} + return pushforward!( + f!, y, dy, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx) + ) +end + +function value_and_pushforward( f!::F, y, backend, x, dx, extras::PullbackPushforwardExtras ) where {F} @compat (; pullback_extras) = extras diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 267492eb1..90fd64642 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -25,18 +25,18 @@ Compute the Hessian matrix of the function `f` at point `x`, overwriting `hess`. function hessian! end """ - value_gradient_and_hessian(f, backend, x, [extras]) -> (y, grad, hess) + value_hessian_and_hessian(f, backend, x, [extras]) -> (y, grad, hess) -Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`. +Compute the value, hessian vector and Hessian matrix of the function `f` at point `x`. """ -function value_gradient_and_hessian end +function value_hessian_and_hessian end """ - value_gradient_and_hessian!(f, grad, hess, backend, x, [extras]) -> (y, grad, hess) + value_hessian_and_hessian!(f, grad, hess, backend, x, [extras]) -> (y, grad, hess) -Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. +Compute the value, hessian vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. """ -function value_gradient_and_hessian! end +function value_hessian_and_hessian! end ## Preparation @@ -51,20 +51,38 @@ struct NoHessianExtras <: HessianExtras end struct HVPGradientHessianExtras{E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras hvp_extras::E2 - gradient_extras::E1 + hessian_extras::E1 end function prepare_hessian(f::F, backend::AbstractADType, x) where {F} v = basis(backend, x, first(CartesianIndices(x))) hvp_extras = prepare_hvp(f, backend, x, v) - gradient_extras = prepare_gradient(f, maybe_inner(backend), x) - return HVPGradientHessianExtras(hvp_extras, gradient_extras) + hessian_extras = prepare_hessian(f, maybe_inner(backend), x) + return HVPGradientHessianExtras(hvp_extras, hessian_extras) end ## One argument +function value_gradient_and_hessian(f::F, backend::AbstractADType, x) where {F} + return value_gradient_and_hessian(f, backend, x, prepare_hessian(f, backend, x)) +end + +function value_gradient_and_hessian!(f::F, grad, hess, backend::AbstractADType, x) where {F} + return value_gradient_and_hessian!( + f, grad, hess, backend, x, prepare_hessian(f, backend, x) + ) +end + +function hessian(f::F, backend::AbstractADType, x) where {F} + return hessian(f, backend, x, prepare_hessian(f, backend, x)) +end + +function hessian!(f::F, hess, backend::AbstractADType, x) where {F} + return hessian!(f, hess, backend, x, prepare_hessian(f, backend, x)) +end + function hessian( - f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) + f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -77,11 +95,7 @@ function hessian( end function hessian!( - f::F, - hess, - backend::AbstractADType, - x, - extras::HessianExtras=prepare_hessian(f, backend, x), + f::F, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -93,23 +107,18 @@ function hessian!( return hess end -function value_gradient_and_hessian( - f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) +function value_hessian_and_hessian( + f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} - y, grad = value_and_gradient(f, maybe_inner(backend), x, extras.gradient_extras) + y, grad = value_and_hessian(f, maybe_inner(backend), x, extras.hessian_extras) hess = hessian(f, backend, x, extras) return y, grad, hess end -function value_gradient_and_hessian!( - f::F, - grad, - hess, - backend::AbstractADType, - x, - extras::HessianExtras=prepare_hessian(f, backend, x), +function value_hessian_and_hessian!( + f::F, grad, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} - y, _ = value_and_gradient!(f, grad, maybe_inner(backend), x, extras.gradient_extras) + y, _ = value_and_hessian!(f, grad, maybe_inner(backend), x, extras.hessian_extras) hessian!(f, hess, backend, x, extras) return y, grad, hess end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index d323db06c..8c078e0b7 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -80,10 +80,10 @@ function prepare_hvp(f::F, backend::AbstractADType, x, v) where {F} end function prepare_hvp(f::F, backend::SecondOrder, x, v) where {F} - return prepare_hvp_aux(f, backend, x, v, hvp_mode(backend)) + return prepare_hvp(f, backend, x, v, hvp_mode(backend)) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) where {F} # pushforward of many pushforwards in theory, but pushforward of gradient in practice inner_backend = nested(inner(backend)) inner_gradient_closure(z) = gradient(f, inner_backend, z) @@ -93,7 +93,7 @@ function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) return ForwardOverForwardHVPExtras(inner_gradient_closure, outer_pushforward_extras) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) where {F} # pushforward of gradient inner_backend = nested(inner(backend)) inner_gradient_closure(z) = gradient(f, inner_backend, z) @@ -103,7 +103,7 @@ function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) return ForwardOverReverseHVPExtras(inner_gradient_closure, outer_pushforward_extras) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) where {F} # gradient of pushforward # uses v in the closure inner_backend = nested(inner(backend)) @@ -119,7 +119,7 @@ function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) ) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverReverse) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ReverseOverReverse) where {F} # pullback of the gradient inner_backend = nested(inner(backend)) inner_gradient_closure(z) = gradient(f, inner_backend, z) @@ -142,76 +142,68 @@ end ## One argument -function hvp( - f::F, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) -) where {F} - return hvp(f, SecondOrder(backend, backend), x, v, extras) +function hvp(f::F, backend::AbstractADType, x, v) where {F} + return hvp(f, backend, x, prepare_hvp(f, backend, x, v)) end -function hvp( - f::F, backend::SecondOrder, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) -) where {F} - return hvp_aux(f, backend, x, v, extras) +function hvp!(f::F, p, backend::AbstractADType, x, v) where {F} + return hvp!(f, p, backend, x, prepare_hvp(f, backend, x, v)) +end + +function hvp(f::F, backend::AbstractADType, x, v, extras::HVPExtras) where {F} + return hvp(f, SecondOrder(backend, backend), x, v, extras) end -function hvp_aux(f::F, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} +function hvp(f::F, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward( inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux(f::F, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} +function hvp(f::F, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward( inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux(f::F, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} +function hvp(f::F, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} @compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras inner_pushforward_closure = inner_pushforward_closure_generator(v) return gradient(inner_pushforward_closure, outer(backend), x, outer_gradient_extras) end -function hvp_aux(f::F, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} +function hvp(f::F, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pullback_extras) = extras return pullback(inner_gradient_closure, outer(backend), x, v, outer_pullback_extras) end -function hvp!( - f::F, p, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) -) where {F} +function hvp!(f::F, p, backend::AbstractADType, x, v, extras::HVPExtras) where {F} return hvp!(f, p, SecondOrder(backend, backend), x, v, extras) end -function hvp!( - f::F, p, backend::SecondOrder, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) -) where {F} - return hvp_aux!(f, p, backend, x, v, extras) -end - -function hvp_aux!(f::F, p, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} +function hvp!(f::F, p, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward!( inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux!(f::F, p, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} +function hvp!(f::F, p, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward!( inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux!(f::F, p, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} +function hvp!(f::F, p, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} @compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras inner_pushforward_closure = inner_pushforward_closure_generator(v) return gradient!(inner_pushforward_closure, p, outer(backend), x, outer_gradient_extras) end -function hvp_aux!(f::F, p, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} +function hvp!(f::F, p, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pullback_extras) = extras return pullback!(inner_gradient_closure, p, outer(backend), x, v, outer_pullback_extras) end diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 5921a8a29..c16c08309 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -69,30 +69,43 @@ end ## One argument +function value_derivative_and_second_derivative(f::F, backend::AbstractADType, x) where {F} + return value_derivative_and_second_derivative( + f, backend, x, prepare_second_derivative(f, backend, x) + ) +end + +function value_derivative_and_second_derivative!( + f::F, der, der2, backend::AbstractADType, x +) where {F} + return value_derivative_and_second_derivative!( + f, der, der2, backend, x, prepare_second_derivative(f, backend, x) + ) +end + +function second_derivative(f::F, backend::AbstractADType, x) where {F} + return second_derivative(f, backend, x, prepare_second_derivative(f, backend, x)) +end + +function second_derivative!(f::F, der2, backend::AbstractADType, x) where {F} + return second_derivative!(f, der2, backend, x, prepare_second_derivative(f, backend, x)) +end + function second_derivative( - f::F, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return second_derivative(f, SecondOrder(backend, backend), x, extras) end function second_derivative( - f::F, - backend::SecondOrder, - x, - extras::ClosureSecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::SecondOrder, x, extras::ClosureSecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras return derivative(inner_derivative_closure, outer(backend), x, outer_derivative_extras) end function value_derivative_and_second_derivative( - f::F, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return value_derivative_and_second_derivative( f, SecondOrder(backend, backend), x, extras @@ -100,10 +113,7 @@ function value_derivative_and_second_derivative( end function value_derivative_and_second_derivative( - f::F, - backend::SecondOrder, - x, - extras::ClosureSecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::SecondOrder, x, extras::ClosureSecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras y = f(x) @@ -114,21 +124,13 @@ function value_derivative_and_second_derivative( end function second_derivative!( - f::F, - der2, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der2, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return second_derivative!(f, der2, SecondOrder(backend, backend), x, extras) end function second_derivative!( - f::F, - der2, - backend::SecondOrder, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der2, backend::SecondOrder, x, extras::SecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras return derivative!( @@ -137,12 +139,7 @@ function second_derivative!( end function value_derivative_and_second_derivative!( - f::F, - der, - der2, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der, der2, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return value_derivative_and_second_derivative!( f, der, der2, SecondOrder(backend, backend), x, extras @@ -150,12 +147,7 @@ function value_derivative_and_second_derivative!( end function value_derivative_and_second_derivative!( - f::F, - der, - der2, - backend::SecondOrder, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der, der2, backend::SecondOrder, x, extras::SecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras y = f(x) From 7143ad595024d3d73c3bf027a055f02cb779478e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 07:26:55 +0200 Subject: [PATCH 02/13] Typos --- .../src/first_order/derivative.jl | 26 +++-------------- .../src/first_order/jacobian.jl | 16 +++++------ .../src/first_order/pullback.jl | 28 +++---------------- .../src/first_order/pushforward.jl | 23 ++------------- .../src/second_order/hessian.jl | 14 +++++----- 5 files changed, 26 insertions(+), 81 deletions(-) diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 3baa32338..a78e69c05 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -133,43 +133,25 @@ function derivative!(f!::F, y, der, backend::AbstractADType, x) where {F} end function value_and_derivative( - f!::F, - y, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward(f!, y, backend, x, one(x), extras.pushforward_extras) end function value_and_derivative!( - f!::F, - y, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras) end function derivative( - f!::F, - y, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward(f!, y, backend, x, one(x), extras.pushforward_extras) end function derivative!( - f!::F, - y, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 825071169..523c19fae 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -98,20 +98,20 @@ end ## One argument -function value_and_jacobian(f!::F, y, backend::AbstractADType, x) where {F} - return value_and_jacobian(f!, y, backend, x, prepare_jacobian(f, backend, x)) +function value_and_jacobian(f::F, backend::AbstractADType, x) where {F} + return value_and_jacobian(f, backend, x, prepare_jacobian(f, backend, x)) end -function value_and_jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} - return value_and_jacobian!(f!, y, jac, backend, x, prepare_jacobian(f, backend, x)) +function value_and_jacobian!(f::F, jac, backend::AbstractADType, x) where {F} + return value_and_jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x)) end -function jacobian(f!::F, y, backend::AbstractADType, x) where {F} - return jacobian(f!, y, backend, x, prepare_jacobian(f, backend, x)) +function jacobian(f::F, backend::AbstractADType, x) where {F} + return jacobian(f, backend, x, prepare_jacobian(f, backend, x)) end -function jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} - return jacobian!(f!, y, jac, backend, x, prepare_jacobian(f, backend, x)) +function jacobian!(f::F, jac, backend::AbstractADType, x) where {F} + return jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x)) end function value_and_jacobian( diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index b5f343802..441004656 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -147,11 +147,7 @@ function pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F} end function value_and_pullback( - f::F, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), + f::F, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} return value_and_pullback(f, backend, x, dy, extras) end @@ -178,34 +174,18 @@ function value_and_pullback( end function value_and_pullback!( - f::F, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), + f::F, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} y, new_dx = value_and_pullback(f, backend, x, dy, extras) return y, copyto!(dx, new_dx) end -function pullback( - f::F, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), -) where {F} +function pullback(f::F, backend::AbstractADType, x, dy, extras::PullbackExtras) where {F} return value_and_pullback(f, backend, x, dy, extras)[2] end function pullback!( - f::F, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), + f::F, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} return value_and_pullback!(f, dx, backend, x, dy, extras)[2] end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 4ac318c69..a80b3c4dd 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -231,37 +231,20 @@ function value_and_pushforward( end function value_and_pushforward!( - f!::F, - y, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), + f!::F, y, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} y, new_dy = value_and_pushforward(f!, y, backend, x, dx, extras) return y, copyto!(dy, new_dy) end function pushforward( - f!::F, - y, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), + f!::F, y, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward(f!, y, backend, x, dx, extras)[2] end function pushforward!( - f!::F, - y, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), + f!::F, y, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward!(f!, y, dy, backend, x, dx, extras)[2] end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 90fd64642..70e32502b 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -51,14 +51,14 @@ struct NoHessianExtras <: HessianExtras end struct HVPGradientHessianExtras{E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras hvp_extras::E2 - hessian_extras::E1 + gradient_extras::E1 end function prepare_hessian(f::F, backend::AbstractADType, x) where {F} v = basis(backend, x, first(CartesianIndices(x))) hvp_extras = prepare_hvp(f, backend, x, v) - hessian_extras = prepare_hessian(f, maybe_inner(backend), x) - return HVPGradientHessianExtras(hvp_extras, hessian_extras) + gradient_extras = prepare_gradient(f, maybe_inner(backend), x) + return HVPGradientHessianExtras(hvp_extras, gradient_extras) end ## One argument @@ -107,18 +107,18 @@ function hessian!( return hess end -function value_hessian_and_hessian( +function value_gradient_and_hessian( f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} - y, grad = value_and_hessian(f, maybe_inner(backend), x, extras.hessian_extras) + y, grad = value_and_gradient(f, maybe_inner(backend), x, extras.gradient_extras) hess = hessian(f, backend, x, extras) return y, grad, hess end -function value_hessian_and_hessian!( +function value_gradient_and_hessian!( f::F, grad, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} - y, _ = value_and_hessian!(f, grad, maybe_inner(backend), x, extras.hessian_extras) + y, _ = value_and_gradient!(f, grad, maybe_inner(backend), x, extras.gradient_extras) hessian!(f, hess, backend, x, extras) return y, grad, hess end From 5c50ddb802ca0f795d7ed07b7f400a861d5aee0a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 07:34:11 +0200 Subject: [PATCH 03/13] Fix --- DifferentiationInterface/src/first_order/pullback.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 441004656..7f2fef34b 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -146,12 +146,6 @@ function pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F} return pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy)) end -function value_and_pullback( - f::F, backend::AbstractADType, x, dy, extras::PullbackExtras -) where {F} - return value_and_pullback(f, backend, x, dy, extras) -end - function value_and_pullback( f::F, backend, x, dy, extras::PushforwardPullbackExtras ) where {F} From bbfc13f219187068769db77a1a9b413c192563ad Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 08:16:59 +0200 Subject: [PATCH 04/13] SecondOrder for hvp --- .../src/second_order/hvp.jl | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 8c078e0b7..1933f3b0e 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -154,27 +154,35 @@ function hvp(f::F, backend::AbstractADType, x, v, extras::HVPExtras) where {F} return hvp(f, SecondOrder(backend, backend), x, v, extras) end -function hvp(f::F, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} +function hvp( + f::F, backend::SecondOrder, x, v, extras::ForwardOverForwardHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward( inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras ) end -function hvp(f::F, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} +function hvp( + f::F, backend::SecondOrder, x, v, extras::ForwardOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward( inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras ) end -function hvp(f::F, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} +function hvp( + f::F, backend::SecondOrder, x, v, extras::ReverseOverForwardHVPExtras +) where {F} @compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras inner_pushforward_closure = inner_pushforward_closure_generator(v) return gradient(inner_pushforward_closure, outer(backend), x, outer_gradient_extras) end -function hvp(f::F, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} +function hvp( + f::F, backend::SecondOrder, x, v, extras::ReverseOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pullback_extras) = extras return pullback(inner_gradient_closure, outer(backend), x, v, outer_pullback_extras) end @@ -183,27 +191,35 @@ function hvp!(f::F, p, backend::AbstractADType, x, v, extras::HVPExtras) where { return hvp!(f, p, SecondOrder(backend, backend), x, v, extras) end -function hvp!(f::F, p, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} +function hvp!( + f::F, p, backend::SecondOrder, x, v, extras::ForwardOverForwardHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward!( inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras ) end -function hvp!(f::F, p, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} +function hvp!( + f::F, p, backend::SecondOrder, x, v, extras::ForwardOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward!( inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras ) end -function hvp!(f::F, p, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} +function hvp!( + f::F, p, backend::SecondOrder, x, v, extras::ReverseOverForwardHVPExtras +) where {F} @compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras inner_pushforward_closure = inner_pushforward_closure_generator(v) return gradient!(inner_pushforward_closure, p, outer(backend), x, outer_gradient_extras) end -function hvp!(f::F, p, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} +function hvp!( + f::F, p, backend::SecondOrder, x, v, extras::ReverseOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pullback_extras) = extras return pullback!(inner_gradient_closure, p, outer(backend), x, v, outer_pullback_extras) end From 06f3666c8cfc4859280f39f802d4b9de967402e4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 08:35:26 +0200 Subject: [PATCH 05/13] Polyester --- ...ferentiationInterfacePolyesterForwardDiffExt.jl | 3 ++- .../onearg.jl | 14 +++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index fa477471a..aa49cc36b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -11,7 +11,8 @@ using DifferentiationInterface: NoGradientExtras, NoHessianExtras, NoJacobianExtras, - PushforwardExtras + PushforwardExtras, + PushforwardDerivativeExtras using DocStringExtensions using LinearAlgebra: mul! using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 99cfa2f2a..6bedd02b4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -31,28 +31,32 @@ end ## Derivative -function DI.prepare_derivative(f, backend::AutoPolyesterForwardDiff, x) +function DI.prepare_derivative( + f, backend::AutoPolyesterForwardDiff, x +)::PushforwardDerivativeExtras return DI.prepare_derivative(f, single_threaded(backend), x) end function DI.value_and_derivative( - f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras ) return DI.value_and_derivative(f, single_threaded(backend), x, extras) end function DI.value_and_derivative!( - f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras ) return DI.value_and_derivative!(f, der, single_threaded(backend), x, extras) end -function DI.derivative(f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras) +function DI.derivative( + f, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras +) return DI.derivative(f, single_threaded(backend), x, extras) end function DI.derivative!( - f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras ) return DI.derivative!(f, der, single_threaded(backend), x, extras) end From a35c8a019ada9db7a6fcd7d5e6b7142f442593f9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 08:42:21 +0200 Subject: [PATCH 06/13] Typo --- DifferentiationInterface/src/second_order/hessian.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 70e32502b..94a8adf2d 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -25,18 +25,18 @@ Compute the Hessian matrix of the function `f` at point `x`, overwriting `hess`. function hessian! end """ - value_hessian_and_hessian(f, backend, x, [extras]) -> (y, grad, hess) + value_gradient_and_hessian(f, backend, x, [extras]) -> (y, grad, hess) Compute the value, hessian vector and Hessian matrix of the function `f` at point `x`. """ -function value_hessian_and_hessian end +function value_gradient_and_hessian end """ - value_hessian_and_hessian!(f, grad, hess, backend, x, [extras]) -> (y, grad, hess) + value_gradient_and_hessian!(f, grad, hess, backend, x, [extras]) -> (y, grad, hess) Compute the value, hessian vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. """ -function value_hessian_and_hessian! end +function value_gradient_and_hessian! end ## Preparation From 042db75e4bc5d9dbd06a225aee713720ca8b74b4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 09:17:51 +0200 Subject: [PATCH 07/13] Typo --- DifferentiationInterface/src/second_order/hvp.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 1933f3b0e..1904f9744 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -143,11 +143,11 @@ end ## One argument function hvp(f::F, backend::AbstractADType, x, v) where {F} - return hvp(f, backend, x, prepare_hvp(f, backend, x, v)) + return hvp(f, backend, x, v, prepare_hvp(f, backend, x, v)) end function hvp!(f::F, p, backend::AbstractADType, x, v) where {F} - return hvp!(f, p, backend, x, prepare_hvp(f, backend, x, v)) + return hvp!(f, p, backend, x, v, prepare_hvp(f, backend, x, v)) end function hvp(f::F, backend::AbstractADType, x, v, extras::HVPExtras) where {F} From 82b2edbdff58a3d4a255e877e3abfa8bed9645d8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 09:32:39 +0200 Subject: [PATCH 08/13] Add correctness tests --- .../src/tests/correctness.jl | 923 +++++++++--------- 1 file changed, 474 insertions(+), 449 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index e83ebadc5..33f1f1eef 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -9,6 +9,8 @@ function test_scen_intact(new_scen, scen) end end +testset_name(k) = k == 1 ? "No prep" : (k == 2 ? "Different point" : "Same point") + ## Pushforward function test_correctness( @@ -26,26 +28,24 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1, dy1 = value_and_pushforward(f, ba, x, dx, extras) - dy2 = pushforward(f, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1 ≈ dy_true - @test dy2 ≈ dy_true - end + y1, dy1 = value_and_pushforward(f, ba, x, dx, extras...) + dy2 = pushforward(f, ba, x, dx, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1 ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -68,31 +68,29 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - dy1_in = mysimilar(y) - y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras) - - dy2_in = mysimilar(y) - dy2 = pushforward!(f, dy2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1_in ≈ dy_true - @test dy1 ≈ dy_true - @test dy2_in ≈ dy_true - @test dy2 ≈ dy_true - end + dy1_in = mysimilar(y) + y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras...) + + dy2_in = mysimilar(y) + dy2 = pushforward!(f, dy2_in, ba, x, dx, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1_in ≈ dy_true + @test dy1 ≈ dy_true + @test dy2_in ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -116,30 +114,28 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in = mysimilar(y) - y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras) - - y2_in = mysimilar(y) - dy2 = pushforward(f!, y2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1 ≈ dy_true - @test dy2 ≈ dy_true - end + y1_in = mysimilar(y) + y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras...) + + y2_in = mysimilar(y) + dy2 = pushforward(f!, y2_in, ba, x, dx, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1 ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -163,32 +159,30 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in, dy1_in = mysimilar(y), mysimilar(y) - y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras) - - y2_in, dy2_in = mysimilar(y), mysimilar(y) - dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1_in ≈ dy_true - @test dy1 ≈ dy_true - @test dy2_in ≈ dy_true - @test dy2 ≈ dy_true - end + y1_in, dy1_in = mysimilar(y), mysimilar(y) + y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras...) + + y2_in, dy2_in = mysimilar(y), mysimilar(y) + dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1_in ≈ dy_true + @test dy1 ≈ dy_true + @test dy2_in ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -213,27 +207,25 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f, ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f, ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1, dx1 = value_and_pullback(f, ba, x, dy, extras) - - dx2 = pullback(f, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1 ≈ dx_true - @test dx2 ≈ dx_true - end + y1, dx1 = value_and_pullback(f, ba, x, dy, extras...) + + dx2 = pullback(f, ba, x, dy, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1 ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -256,31 +248,29 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f, ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f, ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - dx1_in = mysimilar(x) - y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras) - - dx2_in = mysimilar(x) - dx2 = pullback!(f, dx2_in, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1_in ≈ dx_true - @test dx1 ≈ dx_true - @test dx2_in ≈ dx_true - @test dx2 ≈ dx_true - end + dx1_in = mysimilar(x) + y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras...) + + dx2_in = mysimilar(x) + dx2 = pullback!(f, dx2_in, ba, x, dy, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1_in ≈ dx_true + @test dx1 ≈ dx_true + @test dx2_in ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -304,30 +294,28 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in = mysimilar(y) - y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras) - - y2_in = mysimilar(y) - dx2 = pullback(f!, y2_in, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1 ≈ dx_true - @test dx2 ≈ dx_true - end + y1_in = mysimilar(y) + y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras...) + + y2_in = mysimilar(y) + dx2 = pullback(f!, y2_in, ba, x, dy, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1 ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -351,32 +339,30 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in, dx1_in = mysimilar(y), mysimilar(x) - y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras) - - y2_in, dx2_in = mysimilar(y), mysimilar(x) - dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1_in ≈ dx_true - @test dx1 ≈ dx_true - @test dx2_in ≈ dx_true - @test dx2 ≈ dx_true - end + y1_in, dx1_in = mysimilar(y), mysimilar(x) + y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras...) + + y2_in, dx2_in = mysimilar(y), mysimilar(x) + dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1_in ≈ dx_true + @test dx1 ≈ dx_true + @test dx2_in ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -395,27 +381,29 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_derivative(f, ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f, ref_backend, x) else new_scen.ref(x) end - y1, der1 = value_and_derivative(f, ba, x, extras) - - der2 = derivative(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_derivative(f, ba, mycopy_random(x)),) + ]) + y1, der1 = value_and_derivative(f, ba, x, extras...) + der2 = derivative(f, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1 ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa DerivativeExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1 ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -431,31 +419,34 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_derivative(f, ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f, ref_backend, x) else new_scen.ref(x) end - der1_in = mysimilar(y) - y1, der1 = value_and_derivative!(f, der1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_derivative(f, ba, mycopy_random(x)),) + ]) + der1_in = mysimilar(y) + y1, der1 = value_and_derivative!(f, der1_in, ba, x, extras...) - der2_in = mysimilar(y) - der2 = derivative!(f, der2_in, ba, x, extras) + der2_in = mysimilar(y) + der2 = derivative!(f, der2_in, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1_in ≈ der_true - @test der1 ≈ der_true - @test der2_in ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa DerivativeExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1_in ≈ der_true + @test der1 ≈ der_true + @test der2_in ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -472,30 +463,33 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in = mysimilar(y) - y1, der1 = value_and_derivative(f!, y1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in = mysimilar(y) + y1, der1 = value_and_derivative(f!, y1_in, ba, x, extras...) - y2_in = mysimilar(y) - der2 = derivative(f!, y2_in, ba, x, extras) + y2_in = mysimilar(y) + der2 = derivative(f!, y2_in, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1 ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa DerivativeExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1 ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -512,32 +506,35 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in, der1_in = mysimilar(y), mysimilar(y) - y1, der1 = value_and_derivative!(f!, y1_in, der1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in, der1_in = mysimilar(y), mysimilar(y) + y1, der1 = value_and_derivative!(f!, y1_in, der1_in, ba, x, extras...) - y2_in, der2_in = mysimilar(y), mysimilar(y) - der2 = derivative!(f!, y2_in, der2_in, ba, x, extras) + y2_in, der2_in = mysimilar(y), mysimilar(y) + der2 = derivative!(f!, y2_in, der2_in, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1_in ≈ der_true - @test der1 ≈ der_true - @test der2_in ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa DerivativeExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1_in ≈ der_true + @test der1 ≈ der_true + @test der2_in ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -555,27 +552,30 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_gradient(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, ref_backend, x) else new_scen.ref(x) end - y1, grad1 = value_and_gradient(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_gradient(f, ba, mycopy_random(x)),) + ]) + y1, grad1 = value_and_gradient(f, ba, x, extras...) - grad2 = gradient(f, ba, x, extras) + grad2 = gradient(f, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa GradientExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Gradient value" begin - @test grad1 ≈ grad_true - @test grad2 ≈ grad_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa GradientExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Gradient value" begin + @test grad1 ≈ grad_true + @test grad2 ≈ grad_true + end end end test_scen_intact(new_scen, scen) @@ -591,31 +591,34 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_gradient(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, ref_backend, x) else new_scen.ref(x) end - grad1_in = mysimilar(x) - y1, grad1 = value_and_gradient!(f, grad1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_gradient(f, ba, mycopy_random(x)),) + ]) + grad1_in = mysimilar(x) + y1, grad1 = value_and_gradient!(f, grad1_in, ba, x, extras...) - grad2_in = mysimilar(x) - grad2 = gradient!(f, grad2_in, ba, x, extras) + grad2_in = mysimilar(x) + grad2 = gradient!(f, grad2_in, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa GradientExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Gradient value" begin - @test grad1_in ≈ grad_true - @test grad1 ≈ grad_true - @test grad2_in ≈ grad_true - @test grad2 ≈ grad_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa GradientExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Gradient value" begin + @test grad1_in ≈ grad_true + @test grad1 ≈ grad_true + @test grad2_in ≈ grad_true + @test grad2 ≈ grad_true + end end end test_scen_intact(new_scen, scen) @@ -633,27 +636,30 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_jacobian(f, ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f, ref_backend, x) else new_scen.ref(x) end - y1, jac1 = value_and_jacobian(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_jacobian(f, ba, mycopy_random(x)),) + ]) + y1, jac1 = value_and_jacobian(f, ba, x, extras...) - jac2 = jacobian(f, ba, x, extras) + jac2 = jacobian(f, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1 ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa JacobianExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1 ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -669,31 +675,34 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_jacobian(f, ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f, ref_backend, x) else new_scen.ref(x) end - jac1_in = mysimilar(jac_true) - y1, jac1 = value_and_jacobian!(f, jac1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_jacobian(f, ba, mycopy_random(x)),) + ]) + jac1_in = mysimilar(jac_true) + y1, jac1 = value_and_jacobian!(f, jac1_in, ba, x, extras...) - jac2_in = mysimilar(jac_true) - jac2 = jacobian!(f, jac2_in, ba, x, extras) + jac2_in = mysimilar(jac_true) + jac2 = jacobian!(f, jac2_in, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1_in ≈ jac_true - @test jac1 ≈ jac_true - @test jac2_in ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa JacobianExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1_in ≈ jac_true + @test jac1 ≈ jac_true + @test jac2_in ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -710,30 +719,33 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in = mysimilar(y) - y1, jac1 = value_and_jacobian(f!, y1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in = mysimilar(y) + y1, jac1 = value_and_jacobian(f!, y1_in, ba, x, extras...) - y2_in = mysimilar(y) - jac2 = jacobian(f!, y2_in, ba, x, extras) + y2_in = mysimilar(y) + jac2 = jacobian(f!, y2_in, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1 ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa JacobianExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1 ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -750,32 +762,35 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in, jac1_in = mysimilar(y), mysimilar(jac_true) - y1, jac1 = value_and_jacobian!(f!, y1_in, jac1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in, jac1_in = mysimilar(y), mysimilar(jac_true) + y1, jac1 = value_and_jacobian!(f!, y1_in, jac1_in, ba, x, extras...) - y2_in, jac2_in = mysimilar(y), mysimilar(jac_true) - jac2 = jacobian!(f!, y2_in, jac2_in, ba, x, extras) + y2_in, jac2_in = mysimilar(y), mysimilar(jac_true) + jac2 = jacobian!(f!, y2_in, jac2_in, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1_in ≈ jac_true - @test jac1 ≈ jac_true - @test jac2_in ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa JacobianExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1_in ≈ jac_true + @test jac1 ≈ jac_true + @test jac2_in ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -793,7 +808,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_second_derivative(f, ba, mycopy_random(x)) der1_true = if ref_backend isa AbstractADType derivative(f, maybe_inner(ref_backend), x) else @@ -805,22 +819,26 @@ function test_correctness( new_scen.ref(x) end - der21 = second_derivative(f, ba, x, extras) - y2, der12, der22 = value_derivative_and_second_derivative(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_second_derivative(f, ba, mycopy_random(x)),) + ]) + der21 = second_derivative(f, ba, x, extras...) + y2, der12, der22 = value_derivative_and_second_derivative(f, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa SecondDerivativeExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "First derivative value" begin - @test der12 ≈ der1_true - end - @testset "Second derivative value" begin - @test der21 ≈ der2_true - @test der22 ≈ der2_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa SecondDerivativeExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "First derivative value" begin + @test der12 ≈ der1_true + end + @testset "Second derivative value" begin + @test der21 ≈ der2_true + @test der22 ≈ der2_true + end end end test_scen_intact(new_scen, scen) @@ -836,7 +854,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_second_derivative(f, ba, mycopy_random(x)) der1_true = if ref_backend isa AbstractADType derivative(f, maybe_inner(ref_backend), x) else @@ -848,30 +865,34 @@ function test_correctness( new_scen.ref(x) end - der21_in = mysimilar(y) - der21 = second_derivative!(f, der21_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_second_derivative(f, ba, mycopy_random(x)),) + ]) + der21_in = mysimilar(y) + der21 = second_derivative!(f, der21_in, ba, x, extras...) - der12_in, der22_in = mysimilar(y), mysimilar(y) - y2, der12, der22 = value_derivative_and_second_derivative!( - f, der12_in, der22_in, ba, x, extras - ) + der12_in, der22_in = mysimilar(y), mysimilar(y) + y2, der12, der22 = value_derivative_and_second_derivative!( + f, der12_in, der22_in, ba, x, extras + ) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa SecondDerivativeExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "Derivative value" begin - @test der12_in ≈ der1_true - @test der12 ≈ der1_true - end - @testset "Second derivative value" begin - @test der21_in ≈ der2_true - @test der22_in ≈ der2_true - @test der21 ≈ der2_true - @test der22 ≈ der2_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa SecondDerivativeExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Derivative value" begin + @test der12_in ≈ der1_true + @test der12 ≈ der1_true + end + @testset "Second derivative value" begin + @test der21_in ≈ der2_true + @test der22_in ≈ der2_true + @test der21 ≈ der2_true + @test der22 ≈ der2_true + end end end test_scen_intact(new_scen, scen) @@ -895,21 +916,19 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_hvp_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - p1 = hvp(f, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HVPExtras - end - @testset "HVP value" begin - @test p1 ≈ p_true - end + p1 = hvp(f, ba, x, dx, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa HVPExtras + end + @testset "HVP value" begin + @test p1 ≈ p_true end end end @@ -932,23 +951,21 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_hvp_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), + (prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - p1_in = mysimilar(x) - p1 = hvp!(f, p1_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HVPExtras - end - @testset "HVP value" begin - @test p1_in ≈ p_true - @test p1 ≈ p_true - end + p1_in = mysimilar(x) + p1 = hvp!(f, p1_in, ba, x, dx, extras...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa HVPExtras + end + @testset "HVP value" begin + @test p1_in ≈ p_true + @test p1 ≈ p_true end end end @@ -967,7 +984,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_hessian(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else @@ -979,22 +995,26 @@ function test_correctness( new_scen.ref(x) end - hess1 = hessian(f, ba, x, extras) - y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_hessian(f, ba, mycopy_random(x)),) + ]) + hess1 = hessian(f, ba, x, extras...) + y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HessianExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "Gradient value" begin - @test grad2 ≈ grad_true - end - @testset "Hessian value" begin - @test hess1 ≈ hess_true - @test hess2 ≈ hess_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa HessianExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Gradient value" begin + @test grad2 ≈ grad_true + end + @testset "Hessian value" begin + @test hess1 ≈ hess_true + @test hess2 ≈ hess_true + end end end test_scen_intact(new_scen, scen) @@ -1010,7 +1030,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_hessian(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else @@ -1022,27 +1041,33 @@ function test_correctness( new_scen.ref(x) end - hess1_in = mysimilar(hess_true) - hess1 = hessian!(f, hess1_in, ba, x, extras) - grad2_in, hess2_in = mysimilar(grad_true), mysimilar(hess_true) - y2, grad2, hess2 = value_gradient_and_hessian!(f, grad2_in, hess2_in, ba, x, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HessianExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "Gradient value" begin - @test grad2_in ≈ grad_true - @test grad2 ≈ grad_true - end - @testset "Hessian value" begin - @test hess1_in ≈ hess_true - @test hess2_in ≈ hess_true - @test hess1 ≈ hess_true - @test hess2 ≈ hess_true + @testset "$(testset_name(k))" for (k, extras) in enumerate([ + (), (prepare_hessian(f, ba, mycopy_random(x)),) + ]) + hess1_in = mysimilar(hess_true) + hess1 = hessian!(f, hess1_in, ba, x, extras...) + grad2_in, hess2_in = mysimilar(grad_true), mysimilar(hess_true) + y2, grad2, hess2 = value_gradient_and_hessian!( + f, grad2_in, hess2_in, ba, x, extras... + ) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa HessianExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Gradient value" begin + @test grad2_in ≈ grad_true + @test grad2 ≈ grad_true + end + @testset "Hessian value" begin + @test hess1_in ≈ hess_true + @test hess2_in ≈ hess_true + @test hess1 ≈ hess_true + @test hess2 ≈ hess_true + end end end test_scen_intact(new_scen, scen) From 1393aa19ba0854ee293a0ceb83382090f33899f9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 09:37:53 +0200 Subject: [PATCH 09/13] Extras type --- .../src/tests/correctness.jl | 188 +++++++++--------- 1 file changed, 94 insertions(+), 94 deletions(-) diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index 33f1f1eef..be33068d9 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -28,17 +28,17 @@ function test_correctness( new_scen.ref(x, dx) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)),), (prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)),), ]) - y1, dy1 = value_and_pushforward(f, ba, x, dx, extras...) - dy2 = pushforward(f, ba, x, dx, extras...) + y1, dy1 = value_and_pushforward(f, ba, x, dx, extras_tup...) + dy2 = pushforward(f, ba, x, dx, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PushforwardExtras + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras end @testset "Primal value" begin @test y1 ≈ y @@ -68,20 +68,20 @@ function test_correctness( new_scen.ref(x, dx) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)),), (prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)),), ]) dy1_in = mysimilar(y) - y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras...) + y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras_tup...) dy2_in = mysimilar(y) - dy2 = pushforward!(f, dy2_in, ba, x, dx, extras...) + dy2 = pushforward!(f, dy2_in, ba, x, dx, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PushforwardExtras + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras end @testset "Primal value" begin @test y1 ≈ y @@ -114,20 +114,20 @@ function test_correctness( new_scen.ref(x, dx) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)),), (prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)),), ]) y1_in = mysimilar(y) - y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras...) + y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras_tup...) y2_in = mysimilar(y) - dy2 = pushforward(f!, y2_in, ba, x, dx, extras...) + dy2 = pushforward(f!, y2_in, ba, x, dx, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PushforwardExtras + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -159,20 +159,20 @@ function test_correctness( new_scen.ref(x, dx) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)),), (prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)),), ]) y1_in, dy1_in = mysimilar(y), mysimilar(y) - y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras...) + y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras_tup...) y2_in, dy2_in = mysimilar(y), mysimilar(y) - dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras...) + dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PushforwardExtras + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -207,18 +207,18 @@ function test_correctness( new_scen.ref(x, dy) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)),), (prepare_pullback_same_point(f, ba, x, mycopy_random(dy)),), ]) - y1, dx1 = value_and_pullback(f, ba, x, dy, extras...) + y1, dx1 = value_and_pullback(f, ba, x, dy, extras_tup...) - dx2 = pullback(f, ba, x, dy, extras...) + dx2 = pullback(f, ba, x, dy, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PullbackExtras + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras end @testset "Primal value" begin @test y1 ≈ y @@ -248,20 +248,20 @@ function test_correctness( new_scen.ref(x, dy) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)),), (prepare_pullback_same_point(f, ba, x, mycopy_random(dy)),), ]) dx1_in = mysimilar(x) - y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras...) + y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras_tup...) dx2_in = mysimilar(x) - dx2 = pullback!(f, dx2_in, ba, x, dy, extras...) + dx2 = pullback!(f, dx2_in, ba, x, dy, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PullbackExtras + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras end @testset "Primal value" begin @test y1 ≈ y @@ -294,20 +294,20 @@ function test_correctness( new_scen.ref(x, dy) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)),), (prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)),), ]) y1_in = mysimilar(y) - y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras...) + y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras_tup...) y2_in = mysimilar(y) - dx2 = pullback(f!, y2_in, ba, x, dy, extras...) + dx2 = pullback(f!, y2_in, ba, x, dy, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PullbackExtras + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -339,20 +339,20 @@ function test_correctness( new_scen.ref(x, dy) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)),), (prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)),), ]) y1_in, dx1_in = mysimilar(y), mysimilar(x) - y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras...) + y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras_tup...) y2_in, dx2_in = mysimilar(y), mysimilar(x) - dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras...) + dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa PullbackExtras + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -387,15 +387,15 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_derivative(f, ba, mycopy_random(x)),) ]) - y1, der1 = value_and_derivative(f, ba, x, extras...) - der2 = derivative(f, ba, x, extras...) + y1, der1 = value_and_derivative(f, ba, x, extras_tup...) + der2 = derivative(f, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa DerivativeExtras + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras end @testset "Primal value" begin @test y1 ≈ y @@ -425,18 +425,18 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_derivative(f, ba, mycopy_random(x)),) ]) der1_in = mysimilar(y) - y1, der1 = value_and_derivative!(f, der1_in, ba, x, extras...) + y1, der1 = value_and_derivative!(f, der1_in, ba, x, extras_tup...) der2_in = mysimilar(y) - der2 = derivative!(f, der2_in, ba, x, extras...) + der2 = derivative!(f, der2_in, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa DerivativeExtras + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras end @testset "Primal value" begin @test y1 ≈ y @@ -469,18 +469,18 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)),) ]) y1_in = mysimilar(y) - y1, der1 = value_and_derivative(f!, y1_in, ba, x, extras...) + y1, der1 = value_and_derivative(f!, y1_in, ba, x, extras_tup...) y2_in = mysimilar(y) - der2 = derivative(f!, y2_in, ba, x, extras...) + der2 = derivative(f!, y2_in, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa DerivativeExtras + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -512,18 +512,18 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)),) ]) y1_in, der1_in = mysimilar(y), mysimilar(y) - y1, der1 = value_and_derivative!(f!, y1_in, der1_in, ba, x, extras...) + y1, der1 = value_and_derivative!(f!, y1_in, der1_in, ba, x, extras_tup...) y2_in, der2_in = mysimilar(y), mysimilar(y) - der2 = derivative!(f!, y2_in, der2_in, ba, x, extras...) + der2 = derivative!(f!, y2_in, der2_in, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa DerivativeExtras + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -558,16 +558,16 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_gradient(f, ba, mycopy_random(x)),) ]) - y1, grad1 = value_and_gradient(f, ba, x, extras...) + y1, grad1 = value_and_gradient(f, ba, x, extras_tup...) - grad2 = gradient(f, ba, x, extras...) + grad2 = gradient(f, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa GradientExtras + @test isempty(extras_tup) || only(extras_tup) isa GradientExtras end @testset "Primal value" begin @test y1 ≈ y @@ -597,18 +597,18 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_gradient(f, ba, mycopy_random(x)),) ]) grad1_in = mysimilar(x) - y1, grad1 = value_and_gradient!(f, grad1_in, ba, x, extras...) + y1, grad1 = value_and_gradient!(f, grad1_in, ba, x, extras_tup...) grad2_in = mysimilar(x) - grad2 = gradient!(f, grad2_in, ba, x, extras...) + grad2 = gradient!(f, grad2_in, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa GradientExtras + @test isempty(extras_tup) || only(extras_tup) isa GradientExtras end @testset "Primal value" begin @test y1 ≈ y @@ -642,16 +642,16 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_jacobian(f, ba, mycopy_random(x)),) ]) - y1, jac1 = value_and_jacobian(f, ba, x, extras...) + y1, jac1 = value_and_jacobian(f, ba, x, extras_tup...) - jac2 = jacobian(f, ba, x, extras...) + jac2 = jacobian(f, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa JacobianExtras + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras end @testset "Primal value" begin @test y1 ≈ y @@ -681,18 +681,18 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_jacobian(f, ba, mycopy_random(x)),) ]) jac1_in = mysimilar(jac_true) - y1, jac1 = value_and_jacobian!(f, jac1_in, ba, x, extras...) + y1, jac1 = value_and_jacobian!(f, jac1_in, ba, x, extras_tup...) jac2_in = mysimilar(jac_true) - jac2 = jacobian!(f, jac2_in, ba, x, extras...) + jac2 = jacobian!(f, jac2_in, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa JacobianExtras + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras end @testset "Primal value" begin @test y1 ≈ y @@ -725,18 +725,18 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)),) ]) y1_in = mysimilar(y) - y1, jac1 = value_and_jacobian(f!, y1_in, ba, x, extras...) + y1, jac1 = value_and_jacobian(f!, y1_in, ba, x, extras_tup...) y2_in = mysimilar(y) - jac2 = jacobian(f!, y2_in, ba, x, extras...) + jac2 = jacobian(f!, y2_in, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa JacobianExtras + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -768,18 +768,18 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)),) ]) y1_in, jac1_in = mysimilar(y), mysimilar(jac_true) - y1, jac1 = value_and_jacobian!(f!, y1_in, jac1_in, ba, x, extras...) + y1, jac1 = value_and_jacobian!(f!, y1_in, jac1_in, ba, x, extras_tup...) y2_in, jac2_in = mysimilar(y), mysimilar(jac_true) - jac2 = jacobian!(f!, y2_in, jac2_in, ba, x, extras...) + jac2 = jacobian!(f!, y2_in, jac2_in, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa JacobianExtras + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras end @testset "Primal value" begin @test y1_in ≈ y @@ -819,15 +819,15 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_second_derivative(f, ba, mycopy_random(x)),) ]) - der21 = second_derivative(f, ba, x, extras...) - y2, der12, der22 = value_derivative_and_second_derivative(f, ba, x, extras...) + der21 = second_derivative(f, ba, x, extras_tup...) + y2, der12, der22 = value_derivative_and_second_derivative(f, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa SecondDerivativeExtras + @test isempty(extras_tup) || only(extras_tup) isa SecondDerivativeExtras end @testset "Primal value" begin @test y2 ≈ y @@ -865,20 +865,20 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_second_derivative(f, ba, mycopy_random(x)),) ]) der21_in = mysimilar(y) - der21 = second_derivative!(f, der21_in, ba, x, extras...) + der21 = second_derivative!(f, der21_in, ba, x, extras_tup...) der12_in, der22_in = mysimilar(y), mysimilar(y) y2, der12, der22 = value_derivative_and_second_derivative!( - f, der12_in, der22_in, ba, x, extras + f, der12_in, der22_in, ba, x, extras_tup ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa SecondDerivativeExtras + @test isempty(extras_tup) || only(extras_tup) isa SecondDerivativeExtras end @testset "Primal value" begin @test y2 ≈ y @@ -916,16 +916,16 @@ function test_correctness( new_scen.ref(x, dx) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),), (prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),), ]) - p1 = hvp(f, ba, x, dx, extras...) + p1 = hvp(f, ba, x, dx, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa HVPExtras + @test isempty(extras_tup) || only(extras_tup) isa HVPExtras end @testset "HVP value" begin @test p1 ≈ p_true @@ -951,17 +951,17 @@ function test_correctness( new_scen.ref(x, dx) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),), (prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),), ]) p1_in = mysimilar(x) - p1 = hvp!(f, p1_in, ba, x, dx, extras...) + p1 = hvp!(f, p1_in, ba, x, dx, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa HVPExtras + @test isempty(extras_tup) || only(extras_tup) isa HVPExtras end @testset "HVP value" begin @test p1_in ≈ p_true @@ -995,15 +995,15 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_hessian(f, ba, mycopy_random(x)),) ]) - hess1 = hessian(f, ba, x, extras...) - y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras...) + hess1 = hessian(f, ba, x, extras_tup...) + y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras_tup...) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa HessianExtras + @test isempty(extras_tup) || only(extras_tup) isa HessianExtras end @testset "Primal value" begin @test y2 ≈ y @@ -1041,19 +1041,19 @@ function test_correctness( new_scen.ref(x) end - @testset "$(testset_name(k))" for (k, extras) in enumerate([ + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ (), (prepare_hessian(f, ba, mycopy_random(x)),) ]) hess1_in = mysimilar(hess_true) - hess1 = hessian!(f, hess1_in, ba, x, extras...) + hess1 = hessian!(f, hess1_in, ba, x, extras_tup...) grad2_in, hess2_in = mysimilar(grad_true), mysimilar(hess_true) y2, grad2, hess2 = value_gradient_and_hessian!( - f, grad2_in, hess2_in, ba, x, extras... + f, grad2_in, hess2_in, ba, x, extras_tup... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin - @test extras isa HessianExtras + @test isempty(extras_tup) || only(extras_tup) isa HessianExtras end @testset "Primal value" begin @test y2 ≈ y From f6ec504ff9751226da2db44e010279e192e6cdc7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 09:46:34 +0200 Subject: [PATCH 10/13] Typo --- DifferentiationInterface/src/first_order/pullback.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 7f2fef34b..cb41543a4 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -135,7 +135,7 @@ function value_and_pullback(f::F, backend::AbstractADType, x, dy) where {F} end function value_and_pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F} - return value_and_pullback!(f, dx, backend, dy, x, prepare_pullback(f, backend, x, dy)) + return value_and_pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy)) end function pullback(f::F, backend::AbstractADType, x, dy) where {F} From 800984c7606eafc2108962e68e60f23e591b6f72 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 09:54:48 +0200 Subject: [PATCH 11/13] Other typo --- DifferentiationInterface/src/first_order/pullback.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index cb41543a4..9b464a085 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -194,7 +194,7 @@ end function value_and_pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F} return value_and_pullback!( - f!, y, dx, backend, dy, x, prepare_pullback(f!, y, backend, x, dy) + f!, y, dx, backend, x, dy, prepare_pullback(f!, y, backend, x, dy) ) end From 8b62e4b5c4f3475152cc7691f7f873b52f06bcaf Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 10:02:30 +0200 Subject: [PATCH 12/13] Typo --- DifferentiationInterfaceTest/src/tests/correctness.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index be33068d9..f937530c2 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -873,7 +873,7 @@ function test_correctness( der12_in, der22_in = mysimilar(y), mysimilar(y) y2, der12, der22 = value_derivative_and_second_derivative!( - f, der12_in, der22_in, ba, x, extras_tup + f, der12_in, der22_in, ba, x, extras_tup... ) let (≈)(x, y) = isapprox(x, y; atol, rtol) From 4c3bbb0002170b8840273798d2ff6ec05addfab7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 7 Jun 2024 10:18:31 +0200 Subject: [PATCH 13/13] Hessian doc --- DifferentiationInterface/src/second_order/hessian.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 94a8adf2d..4501dd817 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -27,14 +27,14 @@ function hessian! end """ value_gradient_and_hessian(f, backend, x, [extras]) -> (y, grad, hess) -Compute the value, hessian vector and Hessian matrix of the function `f` at point `x`. +Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`. """ function value_gradient_and_hessian end """ value_gradient_and_hessian!(f, grad, hess, backend, x, [extras]) -> (y, grad, hess) -Compute the value, hessian vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. +Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. """ function value_gradient_and_hessian! end