diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl index fa477471a..aa49cc36b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl @@ -11,7 +11,8 @@ using DifferentiationInterface: NoGradientExtras, NoHessianExtras, NoJacobianExtras, - PushforwardExtras + PushforwardExtras, + PushforwardDerivativeExtras using DocStringExtensions using LinearAlgebra: mul! using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian! diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 99cfa2f2a..6bedd02b4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -31,28 +31,32 @@ end ## Derivative -function DI.prepare_derivative(f, backend::AutoPolyesterForwardDiff, x) +function DI.prepare_derivative( + f, backend::AutoPolyesterForwardDiff, x +)::PushforwardDerivativeExtras return DI.prepare_derivative(f, single_threaded(backend), x) end function DI.value_and_derivative( - f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras ) return DI.value_and_derivative(f, single_threaded(backend), x, extras) end function DI.value_and_derivative!( - f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras ) return DI.value_and_derivative!(f, der, single_threaded(backend), x, extras) end -function DI.derivative(f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras) +function DI.derivative( + f, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras +) return DI.derivative(f, single_threaded(backend), x, extras) end function DI.derivative!( - f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras ) return DI.derivative!(f, der, single_threaded(backend), x, extras) end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 261d263b5..a78e69c05 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -72,84 +72,86 @@ end ## One argument +function value_and_derivative(f::F, backend::AbstractADType, x) where {F} + return value_and_derivative(f, backend, x, prepare_derivative(f, backend, x)) +end + +function value_and_derivative!(f::F, der, backend::AbstractADType, x) where {F} + return value_and_derivative!(f, der, backend, x, prepare_derivative(f, backend, x)) +end + +function derivative(f::F, backend::AbstractADType, x) where {F} + return derivative(f, backend, x, prepare_derivative(f, backend, x)) +end + +function derivative!(f::F, der, backend::AbstractADType, x) where {F} + return derivative!(f, der, backend, x, prepare_derivative(f, backend, x)) +end + function value_and_derivative( - f::F, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward(f, backend, x, one(x), extras.pushforward_extras) end function value_and_derivative!( - f::F, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward!(f, der, backend, x, one(x), extras.pushforward_extras) end function derivative( - f::F, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward(f, backend, x, one(x), extras.pushforward_extras) end function derivative!( - f::F, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f, backend, x), + f::F, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward!(f, der, backend, x, one(x), extras.pushforward_extras) end ## Two arguments +function value_and_derivative(f!::F, y, backend::AbstractADType, x) where {F} + return value_and_derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x)) +end + +function value_and_derivative!(f!::F, y, der, backend::AbstractADType, x) where {F} + return value_and_derivative!( + f!, y, der, backend, x, prepare_derivative(f!, y, backend, x) + ) +end + +function derivative(f!::F, y, backend::AbstractADType, x) where {F} + return derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x)) +end + +function derivative!(f!::F, y, der, backend::AbstractADType, x) where {F} + return derivative!(f!, y, der, backend, x, prepare_derivative(f!, y, backend, x)) +end + function value_and_derivative( - f!::F, - y, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward(f!, y, backend, x, one(x), extras.pushforward_extras) end function value_and_derivative!( - f!::F, - y, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return value_and_pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras) end function derivative( - f!::F, - y, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward(f!, y, backend, x, one(x), extras.pushforward_extras) end function derivative!( - f!::F, - y, - der, - backend::AbstractADType, - x, - extras::DerivativeExtras=prepare_derivative(f!, y, backend, x), + f!::F, y, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras ) where {F} return pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras) end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 35fcbd41a..aab866768 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -62,34 +62,42 @@ end ## One argument +function value_and_gradient(f::F, backend::AbstractADType, x) where {F} + return value_and_gradient(f, backend, x, prepare_gradient(f, backend, x)) +end + +function value_and_gradient!(f::F, der, backend::AbstractADType, x) where {F} + return value_and_gradient!(f, der, backend, x, prepare_gradient(f, backend, x)) +end + +function gradient(f::F, backend::AbstractADType, x) where {F} + return gradient(f, backend, x, prepare_gradient(f, backend, x)) +end + +function gradient!(f::F, der, backend::AbstractADType, x) where {F} + return gradient!(f, der, backend, x, prepare_gradient(f, backend, x)) +end + function value_and_gradient( - f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x) + f::F, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return value_and_pullback(f, backend, x, one(eltype(x)), extras.pullback_extras) end function value_and_gradient!( - f::F, - grad, - backend::AbstractADType, - x, - extras::GradientExtras=prepare_gradient(f, backend, x), + f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return value_and_pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras) end function gradient( - f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x) + f::F, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return pullback(f, backend, x, one(eltype(x)), extras.pullback_extras) end function gradient!( - f::F, - grad, - backend::AbstractADType, - x, - extras::GradientExtras=prepare_gradient(f, backend, x), + f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} return pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 9847a02fc..523c19fae 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -98,13 +98,23 @@ end ## One argument -function value_and_jacobian( - f::F, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x) -) where {F} - return value_and_jacobian_onearg_aux(f, backend, x, extras) +function value_and_jacobian(f::F, backend::AbstractADType, x) where {F} + return value_and_jacobian(f, backend, x, prepare_jacobian(f, backend, x)) +end + +function value_and_jacobian!(f::F, jac, backend::AbstractADType, x) where {F} + return value_and_jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x)) +end + +function jacobian(f::F, backend::AbstractADType, x) where {F} + return jacobian(f, backend, x, prepare_jacobian(f, backend, x)) +end + +function jacobian!(f::F, jac, backend::AbstractADType, x) where {F} + return jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x)) end -function value_and_jacobian_onearg_aux( +function value_and_jacobian( f::F, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -123,7 +133,7 @@ function value_and_jacobian_onearg_aux( return y, jac end -function value_and_jacobian_onearg_aux( +function value_and_jacobian( f::F, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -139,16 +149,6 @@ function value_and_jacobian_onearg_aux( end function value_and_jacobian!( - f::F, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f, backend, x), -) where {F} - return value_and_jacobian_onearg_aux!(f, jac, backend, x, extras) -end - -function value_and_jacobian_onearg_aux!( f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -167,7 +167,7 @@ function value_and_jacobian_onearg_aux!( return y, jac end -function value_and_jacobian_onearg_aux!( +function value_and_jacobian!( f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} y = f(x) # TODO: remove @@ -182,35 +182,33 @@ function value_and_jacobian_onearg_aux!( return y, jac end -function jacobian( - f::F, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x) -) where {F} +function jacobian(f::F, backend::AbstractADType, x, extras::JacobianExtras) where {F} return value_and_jacobian(f, backend, x, extras)[2] end -function jacobian!( - f::F, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f, backend, x), -) where {F} +function jacobian!(f::F, jac, backend::AbstractADType, x, extras::JacobianExtras) where {F} return value_and_jacobian!(f, jac, backend, x, extras)[2] end ## Two arguments -function value_and_jacobian( - f!::F, - y, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), -) where {F} - return value_and_jacobian_twoarg_aux(f!, y, backend, x, extras) +function value_and_jacobian(f!::F, y, backend::AbstractADType, x) where {F} + return value_and_jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x)) +end + +function value_and_jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} + return value_and_jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x)) +end + +function jacobian(f!::F, y, backend::AbstractADType, x) where {F} + return jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x)) end -function value_and_jacobian_twoarg_aux( +function jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} + return jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x)) +end + +function value_and_jacobian( f!::F, y, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} pushforward_extras_same = prepare_pushforward_same_point( @@ -230,7 +228,7 @@ function value_and_jacobian_twoarg_aux( return y, jac end -function value_and_jacobian_twoarg_aux( +function value_and_jacobian( f!::F, y, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} pullback_extras_same = prepare_pullback_same_point( @@ -251,17 +249,6 @@ function value_and_jacobian_twoarg_aux( end function value_and_jacobian!( - f!::F, - y, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), -) where {F} - return value_and_jacobian_twoarg_aux!(f!, y, jac, backend, x, extras) -end - -function value_and_jacobian_twoarg_aux!( f!::F, y, jac::AbstractMatrix, @@ -286,7 +273,7 @@ function value_and_jacobian_twoarg_aux!( return y, jac end -function value_and_jacobian_twoarg_aux!( +function value_and_jacobian!( f!::F, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} pullback_extras_same = prepare_pullback_same_point( @@ -306,23 +293,12 @@ function value_and_jacobian_twoarg_aux!( return y, jac end -function jacobian( - f!::F, - y, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), -) where {F} +function jacobian(f!::F, y, backend::AbstractADType, x, extras::JacobianExtras) where {F} return value_and_jacobian(f!, y, backend, x, extras)[2] end function jacobian!( - f!::F, - y, - jac, - backend::AbstractADType, - x, - extras::JacobianExtras=prepare_jacobian(f!, y, backend, x), + f!::F, y, jac, backend::AbstractADType, x, extras::JacobianExtras ) where {F} return value_and_jacobian!(f!, y, jac, backend, x, extras)[2] end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index ed7119cd9..9b464a085 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -95,8 +95,14 @@ function prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackSlow) where {F end # Throw error if backend is missing -prepare_pullback_aux(f::F, backend, x, dy, ::PullbackFast) where {F} = throw(MissingBackendError(backend)) -prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackFast) where {F} = throw(MissingBackendError(backend)) + +function prepare_pullback_aux(f, backend, x, dy, ::PullbackFast) + throw(MissingBackendError(backend)) +end + +function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast) + throw(MissingBackendError(backend)) +end ## Preparation (same point) @@ -124,17 +130,23 @@ end ## One argument -function value_and_pullback( - f::F, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), -) where {F} - return value_and_pullback_onearg_aux(f, backend, x, dy, extras) +function value_and_pullback(f::F, backend::AbstractADType, x, dy) where {F} + return value_and_pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy)) end -function value_and_pullback_onearg_aux( +function value_and_pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F} + return value_and_pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy)) +end + +function pullback(f::F, backend::AbstractADType, x, dy) where {F} + return pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy)) +end + +function pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F} + return pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy)) +end + +function value_and_pullback( f::F, backend, x, dy, extras::PushforwardPullbackExtras ) where {F} @compat (; pushforward_extras) = extras @@ -156,52 +168,45 @@ function value_and_pullback_onearg_aux( end function value_and_pullback!( - f::F, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), + f::F, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} y, new_dx = value_and_pullback(f, backend, x, dy, extras) return y, copyto!(dx, new_dx) end -function pullback( - f::F, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), -) where {F} +function pullback(f::F, backend::AbstractADType, x, dy, extras::PullbackExtras) where {F} return value_and_pullback(f, backend, x, dy, extras)[2] end function pullback!( - f::F, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f, backend, x, dy), + f::F, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} return value_and_pullback!(f, dx, backend, x, dy, extras)[2] end ## Two arguments -function value_and_pullback( - f!::F, - y, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), -) where {F} - return value_and_pullback_twoarg_aux(f!, y, backend, x, dy, extras) +function value_and_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F} + return value_and_pullback( + f!, y, backend, x, dy, prepare_pullback(f!, y, backend, x, dy) + ) +end + +function value_and_pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F} + return value_and_pullback!( + f!, y, dx, backend, x, dy, prepare_pullback(f!, y, backend, x, dy) + ) end -function value_and_pullback_twoarg_aux( +function pullback(f!::F, y, backend::AbstractADType, x, dy) where {F} + return pullback(f!, y, backend, x, dy, prepare_pullback(f!, y, backend, x, dy)) +end + +function pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F} + return pullback!(f!, y, dx, backend, x, dy, prepare_pullback(f!, y, backend, x, dy)) +end + +function value_and_pullback( f!::F, y, backend, x, dy, extras::PushforwardPullbackExtras ) where {F} @compat (; pushforward_extras) = extras @@ -217,37 +222,20 @@ function value_and_pullback_twoarg_aux( end function value_and_pullback!( - f!::F, - y, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), + f!::F, y, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} y, new_dx = value_and_pullback(f!, y, backend, x, dy, extras) return y, copyto!(dx, new_dx) end function pullback( - f!::F, - y, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), + f!::F, y, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} return value_and_pullback(f!, y, backend, x, dy, extras)[2] end function pullback!( - f!::F, - y, - dx, - backend::AbstractADType, - x, - dy, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, dy), + f!::F, y, dx, backend::AbstractADType, x, dy, extras::PullbackExtras ) where {F} return value_and_pullback!(f!, y, dx, backend, x, dy, extras)[2] end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index db1a6784d..a80b3c4dd 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -96,8 +96,14 @@ function prepare_pushforward_aux(f!::F, y, backend, x, dx, ::PushforwardSlow) wh end # Throw error if backend is missing -prepare_pushforward_aux(f::F, backend, x, dy, ::PushforwardFast) where {F} = throw(MissingBackendError(backend)) -prepare_pushforward_aux(f!::F, y, backend, x, dy, ::PushforwardFast) where {F} = throw(MissingBackendError(backend)) + +function prepare_pushforward_aux(f, backend, x, dy, ::PushforwardFast) + throw(MissingBackendError(backend)) +end + +function prepare_pushforward_aux(f!, y, backend, x, dy, ::PushforwardFast) + throw(MissingBackendError(backend)) +end ## Preparation (same point) @@ -125,17 +131,25 @@ end ## One argument -function value_and_pushforward( - f::F, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), -) where {F} - return value_and_pushforward_onearg_aux(f, backend, x, dx, extras) +function value_and_pushforward(f::F, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward(f, backend, x, dx, prepare_pushforward(f, backend, x, dx)) +end + +function value_and_pushforward!(f::F, dy, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward!( + f, dy, backend, x, dx, prepare_pushforward(f, backend, x, dx) + ) +end + +function pushforward(f::F, backend::AbstractADType, x, dx) where {F} + return pushforward(f, backend, x, dx, prepare_pushforward(f, backend, x, dx)) +end + +function pushforward!(f::F, dy, backend::AbstractADType, x, dx) where {F} + return pushforward!(f, dy, backend, x, dx, prepare_pushforward(f, backend, x, dx)) end -function value_and_pushforward_onearg_aux( +function value_and_pushforward( f::F, backend, x, dx, extras::PullbackPushforwardExtras ) where {F} @compat (; pullback_extras) = extras @@ -157,52 +171,49 @@ function value_and_pushforward_onearg_aux( end function value_and_pushforward!( - f::F, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), + f::F, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} y, new_dy = value_and_pushforward(f, backend, x, dx, extras) return y, copyto!(dy, new_dy) end function pushforward( - f::F, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), + f::F, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward(f, backend, x, dx, extras)[2] end function pushforward!( - f::F, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f, backend, x, dx), + f::F, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward!(f, dy, backend, x, dx, extras)[2] end ## Two arguments -function value_and_pushforward( - f!::F, - y, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), -) where {F} - return value_and_pushforward_twoarg_aux(f!, y, backend, x, dx, extras) +function value_and_pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward( + f!, y, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx) + ) +end + +function value_and_pushforward!(f!::F, y, dy, backend::AbstractADType, x, dx) where {F} + return value_and_pushforward!( + f!, y, dy, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx) + ) end -function value_and_pushforward_twoarg_aux( +function pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F} + return pushforward(f!, y, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx)) +end + +function pushforward!(f!::F, y, dy, backend::AbstractADType, x, dx) where {F} + return pushforward!( + f!, y, dy, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx) + ) +end + +function value_and_pushforward( f!::F, y, backend, x, dx, extras::PullbackPushforwardExtras ) where {F} @compat (; pullback_extras) = extras @@ -220,37 +231,20 @@ function value_and_pushforward_twoarg_aux( end function value_and_pushforward!( - f!::F, - y, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), + f!::F, y, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} y, new_dy = value_and_pushforward(f!, y, backend, x, dx, extras) return y, copyto!(dy, new_dy) end function pushforward( - f!::F, - y, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), + f!::F, y, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward(f!, y, backend, x, dx, extras)[2] end function pushforward!( - f!::F, - y, - dy, - backend::AbstractADType, - x, - dx, - extras::PushforwardExtras=prepare_pushforward(f!, y, backend, x, dx), + f!::F, y, dy, backend::AbstractADType, x, dx, extras::PushforwardExtras ) where {F} return value_and_pushforward!(f!, y, dy, backend, x, dx, extras)[2] end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 267492eb1..4501dd817 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -63,8 +63,26 @@ end ## One argument +function value_gradient_and_hessian(f::F, backend::AbstractADType, x) where {F} + return value_gradient_and_hessian(f, backend, x, prepare_hessian(f, backend, x)) +end + +function value_gradient_and_hessian!(f::F, grad, hess, backend::AbstractADType, x) where {F} + return value_gradient_and_hessian!( + f, grad, hess, backend, x, prepare_hessian(f, backend, x) + ) +end + +function hessian(f::F, backend::AbstractADType, x) where {F} + return hessian(f, backend, x, prepare_hessian(f, backend, x)) +end + +function hessian!(f::F, hess, backend::AbstractADType, x) where {F} + return hessian!(f, hess, backend, x, prepare_hessian(f, backend, x)) +end + function hessian( - f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) + f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -77,11 +95,7 @@ function hessian( end function hessian!( - f::F, - hess, - backend::AbstractADType, - x, - extras::HessianExtras=prepare_hessian(f, backend, x), + f::F, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -94,7 +108,7 @@ function hessian!( end function value_gradient_and_hessian( - f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) + f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} y, grad = value_and_gradient(f, maybe_inner(backend), x, extras.gradient_extras) hess = hessian(f, backend, x, extras) @@ -102,12 +116,7 @@ function value_gradient_and_hessian( end function value_gradient_and_hessian!( - f::F, - grad, - hess, - backend::AbstractADType, - x, - extras::HessianExtras=prepare_hessian(f, backend, x), + f::F, grad, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras ) where {F} y, _ = value_and_gradient!(f, grad, maybe_inner(backend), x, extras.gradient_extras) hessian!(f, hess, backend, x, extras) diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index d323db06c..1904f9744 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -80,10 +80,10 @@ function prepare_hvp(f::F, backend::AbstractADType, x, v) where {F} end function prepare_hvp(f::F, backend::SecondOrder, x, v) where {F} - return prepare_hvp_aux(f, backend, x, v, hvp_mode(backend)) + return prepare_hvp(f, backend, x, v, hvp_mode(backend)) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) where {F} # pushforward of many pushforwards in theory, but pushforward of gradient in practice inner_backend = nested(inner(backend)) inner_gradient_closure(z) = gradient(f, inner_backend, z) @@ -93,7 +93,7 @@ function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) return ForwardOverForwardHVPExtras(inner_gradient_closure, outer_pushforward_extras) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) where {F} # pushforward of gradient inner_backend = nested(inner(backend)) inner_gradient_closure(z) = gradient(f, inner_backend, z) @@ -103,7 +103,7 @@ function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) return ForwardOverReverseHVPExtras(inner_gradient_closure, outer_pushforward_extras) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) where {F} # gradient of pushforward # uses v in the closure inner_backend = nested(inner(backend)) @@ -119,7 +119,7 @@ function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) ) end -function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverReverse) where {F} +function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ReverseOverReverse) where {F} # pullback of the gradient inner_backend = nested(inner(backend)) inner_gradient_closure(z) = gradient(f, inner_backend, z) @@ -142,76 +142,84 @@ end ## One argument -function hvp( - f::F, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) -) where {F} +function hvp(f::F, backend::AbstractADType, x, v) where {F} + return hvp(f, backend, x, v, prepare_hvp(f, backend, x, v)) +end + +function hvp!(f::F, p, backend::AbstractADType, x, v) where {F} + return hvp!(f, p, backend, x, v, prepare_hvp(f, backend, x, v)) +end + +function hvp(f::F, backend::AbstractADType, x, v, extras::HVPExtras) where {F} return hvp(f, SecondOrder(backend, backend), x, v, extras) end function hvp( - f::F, backend::SecondOrder, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) + f::F, backend::SecondOrder, x, v, extras::ForwardOverForwardHVPExtras ) where {F} - return hvp_aux(f, backend, x, v, extras) -end - -function hvp_aux(f::F, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward( inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux(f::F, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} +function hvp( + f::F, backend::SecondOrder, x, v, extras::ForwardOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward( inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux(f::F, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} +function hvp( + f::F, backend::SecondOrder, x, v, extras::ReverseOverForwardHVPExtras +) where {F} @compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras inner_pushforward_closure = inner_pushforward_closure_generator(v) return gradient(inner_pushforward_closure, outer(backend), x, outer_gradient_extras) end -function hvp_aux(f::F, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} +function hvp( + f::F, backend::SecondOrder, x, v, extras::ReverseOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pullback_extras) = extras return pullback(inner_gradient_closure, outer(backend), x, v, outer_pullback_extras) end -function hvp!( - f::F, p, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) -) where {F} +function hvp!(f::F, p, backend::AbstractADType, x, v, extras::HVPExtras) where {F} return hvp!(f, p, SecondOrder(backend, backend), x, v, extras) end function hvp!( - f::F, p, backend::SecondOrder, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v) + f::F, p, backend::SecondOrder, x, v, extras::ForwardOverForwardHVPExtras ) where {F} - return hvp_aux!(f, p, backend, x, v, extras) -end - -function hvp_aux!(f::F, p, backend, x, v, extras::ForwardOverForwardHVPExtras) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward!( inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux!(f::F, p, backend, x, v, extras::ForwardOverReverseHVPExtras) where {F} +function hvp!( + f::F, p, backend::SecondOrder, x, v, extras::ForwardOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pushforward_extras) = extras return pushforward!( inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras ) end -function hvp_aux!(f::F, p, backend, x, v, extras::ReverseOverForwardHVPExtras) where {F} +function hvp!( + f::F, p, backend::SecondOrder, x, v, extras::ReverseOverForwardHVPExtras +) where {F} @compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras inner_pushforward_closure = inner_pushforward_closure_generator(v) return gradient!(inner_pushforward_closure, p, outer(backend), x, outer_gradient_extras) end -function hvp_aux!(f::F, p, backend, x, v, extras::ReverseOverReverseHVPExtras) where {F} +function hvp!( + f::F, p, backend::SecondOrder, x, v, extras::ReverseOverReverseHVPExtras +) where {F} @compat (; inner_gradient_closure, outer_pullback_extras) = extras return pullback!(inner_gradient_closure, p, outer(backend), x, v, outer_pullback_extras) end diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 5921a8a29..c16c08309 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -69,30 +69,43 @@ end ## One argument +function value_derivative_and_second_derivative(f::F, backend::AbstractADType, x) where {F} + return value_derivative_and_second_derivative( + f, backend, x, prepare_second_derivative(f, backend, x) + ) +end + +function value_derivative_and_second_derivative!( + f::F, der, der2, backend::AbstractADType, x +) where {F} + return value_derivative_and_second_derivative!( + f, der, der2, backend, x, prepare_second_derivative(f, backend, x) + ) +end + +function second_derivative(f::F, backend::AbstractADType, x) where {F} + return second_derivative(f, backend, x, prepare_second_derivative(f, backend, x)) +end + +function second_derivative!(f::F, der2, backend::AbstractADType, x) where {F} + return second_derivative!(f, der2, backend, x, prepare_second_derivative(f, backend, x)) +end + function second_derivative( - f::F, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return second_derivative(f, SecondOrder(backend, backend), x, extras) end function second_derivative( - f::F, - backend::SecondOrder, - x, - extras::ClosureSecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::SecondOrder, x, extras::ClosureSecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras return derivative(inner_derivative_closure, outer(backend), x, outer_derivative_extras) end function value_derivative_and_second_derivative( - f::F, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return value_derivative_and_second_derivative( f, SecondOrder(backend, backend), x, extras @@ -100,10 +113,7 @@ function value_derivative_and_second_derivative( end function value_derivative_and_second_derivative( - f::F, - backend::SecondOrder, - x, - extras::ClosureSecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, backend::SecondOrder, x, extras::ClosureSecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras y = f(x) @@ -114,21 +124,13 @@ function value_derivative_and_second_derivative( end function second_derivative!( - f::F, - der2, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der2, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return second_derivative!(f, der2, SecondOrder(backend, backend), x, extras) end function second_derivative!( - f::F, - der2, - backend::SecondOrder, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der2, backend::SecondOrder, x, extras::SecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras return derivative!( @@ -137,12 +139,7 @@ function second_derivative!( end function value_derivative_and_second_derivative!( - f::F, - der, - der2, - backend::AbstractADType, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der, der2, backend::AbstractADType, x, extras::SecondDerivativeExtras ) where {F} return value_derivative_and_second_derivative!( f, der, der2, SecondOrder(backend, backend), x, extras @@ -150,12 +147,7 @@ function value_derivative_and_second_derivative!( end function value_derivative_and_second_derivative!( - f::F, - der, - der2, - backend::SecondOrder, - x, - extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x), + f::F, der, der2, backend::SecondOrder, x, extras::SecondDerivativeExtras ) where {F} @compat (; inner_derivative_closure, outer_derivative_extras) = extras y = f(x) diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index e83ebadc5..f937530c2 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -9,6 +9,8 @@ function test_scen_intact(new_scen, scen) end end +testset_name(k) = k == 1 ? "No prep" : (k == 2 ? "Different point" : "Same point") + ## Pushforward function test_correctness( @@ -26,26 +28,24 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1, dy1 = value_and_pushforward(f, ba, x, dx, extras) - dy2 = pushforward(f, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1 ≈ dy_true - @test dy2 ≈ dy_true - end + y1, dy1 = value_and_pushforward(f, ba, x, dx, extras_tup...) + dy2 = pushforward(f, ba, x, dx, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1 ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -68,31 +68,29 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - dy1_in = mysimilar(y) - y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras) - - dy2_in = mysimilar(y) - dy2 = pushforward!(f, dy2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1_in ≈ dy_true - @test dy1 ≈ dy_true - @test dy2_in ≈ dy_true - @test dy2 ≈ dy_true - end + dy1_in = mysimilar(y) + y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras_tup...) + + dy2_in = mysimilar(y) + dy2 = pushforward!(f, dy2_in, ba, x, dx, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1_in ≈ dy_true + @test dy1 ≈ dy_true + @test dy2_in ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -116,30 +114,28 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in = mysimilar(y) - y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras) - - y2_in = mysimilar(y) - dy2 = pushforward(f!, y2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1 ≈ dy_true - @test dy2 ≈ dy_true - end + y1_in = mysimilar(y) + y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras_tup...) + + y2_in = mysimilar(y) + dy2 = pushforward(f!, y2_in, ba, x, dx, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1 ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -163,32 +159,30 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)), - prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in, dy1_in = mysimilar(y), mysimilar(y) - y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras) - - y2_in, dy2_in = mysimilar(y), mysimilar(y) - dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1_in ≈ dy_true - @test dy1 ≈ dy_true - @test dy2_in ≈ dy_true - @test dy2 ≈ dy_true - end + y1_in, dy1_in = mysimilar(y), mysimilar(y) + y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras_tup...) + + y2_in, dy2_in = mysimilar(y), mysimilar(y) + dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PushforwardExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1_in ≈ dy_true + @test dy1 ≈ dy_true + @test dy2_in ≈ dy_true + @test dy2 ≈ dy_true end end end @@ -213,27 +207,25 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f, ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f, ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1, dx1 = value_and_pullback(f, ba, x, dy, extras) - - dx2 = pullback(f, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1 ≈ dx_true - @test dx2 ≈ dx_true - end + y1, dx1 = value_and_pullback(f, ba, x, dy, extras_tup...) + + dx2 = pullback(f, ba, x, dy, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1 ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -256,31 +248,29 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f, ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f, ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - dx1_in = mysimilar(x) - y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras) - - dx2_in = mysimilar(x) - dx2 = pullback!(f, dx2_in, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1_in ≈ dx_true - @test dx1 ≈ dx_true - @test dx2_in ≈ dx_true - @test dx2 ≈ dx_true - end + dx1_in = mysimilar(x) + y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras_tup...) + + dx2_in = mysimilar(x) + dx2 = pullback!(f, dx2_in, ba, x, dy, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1_in ≈ dx_true + @test dx1 ≈ dx_true + @test dx2_in ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -304,30 +294,28 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in = mysimilar(y) - y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras) - - y2_in = mysimilar(y) - dx2 = pullback(f!, y2_in, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1 ≈ dx_true - @test dx2 ≈ dx_true - end + y1_in = mysimilar(y) + y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras_tup...) + + y2_in = mysimilar(y) + dx2 = pullback(f!, y2_in, ba, x, dy, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1 ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -351,32 +339,30 @@ function test_correctness( new_scen.ref(x, dy) end - for (k, extras) in enumerate([ - prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)), - prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)),), + (prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - y1_in, dx1_in = mysimilar(y), mysimilar(x) - y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras) - - y2_in, dx2_in = mysimilar(y), mysimilar(x) - dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Cotangent value" begin - @test dx1_in ≈ dx_true - @test dx1 ≈ dx_true - @test dx2_in ≈ dx_true - @test dx2 ≈ dx_true - end + y1_in, dx1_in = mysimilar(y), mysimilar(x) + y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras_tup...) + + y2_in, dx2_in = mysimilar(y), mysimilar(x) + dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa PullbackExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1_in ≈ dx_true + @test dx1 ≈ dx_true + @test dx2_in ≈ dx_true + @test dx2 ≈ dx_true end end end @@ -395,27 +381,29 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_derivative(f, ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f, ref_backend, x) else new_scen.ref(x) end - y1, der1 = value_and_derivative(f, ba, x, extras) - - der2 = derivative(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_derivative(f, ba, mycopy_random(x)),) + ]) + y1, der1 = value_and_derivative(f, ba, x, extras_tup...) + der2 = derivative(f, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1 ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1 ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -431,31 +419,34 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_derivative(f, ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f, ref_backend, x) else new_scen.ref(x) end - der1_in = mysimilar(y) - y1, der1 = value_and_derivative!(f, der1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_derivative(f, ba, mycopy_random(x)),) + ]) + der1_in = mysimilar(y) + y1, der1 = value_and_derivative!(f, der1_in, ba, x, extras_tup...) - der2_in = mysimilar(y) - der2 = derivative!(f, der2_in, ba, x, extras) + der2_in = mysimilar(y) + der2 = derivative!(f, der2_in, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1_in ≈ der_true - @test der1 ≈ der_true - @test der2_in ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1_in ≈ der_true + @test der1 ≈ der_true + @test der2_in ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -472,30 +463,33 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in = mysimilar(y) - y1, der1 = value_and_derivative(f!, y1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in = mysimilar(y) + y1, der1 = value_and_derivative(f!, y1_in, ba, x, extras_tup...) - y2_in = mysimilar(y) - der2 = derivative(f!, y2_in, ba, x, extras) + y2_in = mysimilar(y) + der2 = derivative(f!, y2_in, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1 ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1 ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -512,32 +506,35 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in, der1_in = mysimilar(y), mysimilar(y) - y1, der1 = value_and_derivative!(f!, y1_in, der1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in, der1_in = mysimilar(y), mysimilar(y) + y1, der1 = value_and_derivative!(f!, y1_in, der1_in, ba, x, extras_tup...) - y2_in, der2_in = mysimilar(y), mysimilar(y) - der2 = derivative!(f!, y2_in, der2_in, ba, x, extras) + y2_in, der2_in = mysimilar(y), mysimilar(y) + der2 = derivative!(f!, y2_in, der2_in, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa DerivativeExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Derivative value" begin - @test der1_in ≈ der_true - @test der1 ≈ der_true - @test der2_in ≈ der_true - @test der2 ≈ der_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa DerivativeExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Derivative value" begin + @test der1_in ≈ der_true + @test der1 ≈ der_true + @test der2_in ≈ der_true + @test der2 ≈ der_true + end end end test_scen_intact(new_scen, scen) @@ -555,27 +552,30 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_gradient(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, ref_backend, x) else new_scen.ref(x) end - y1, grad1 = value_and_gradient(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_gradient(f, ba, mycopy_random(x)),) + ]) + y1, grad1 = value_and_gradient(f, ba, x, extras_tup...) - grad2 = gradient(f, ba, x, extras) + grad2 = gradient(f, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa GradientExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Gradient value" begin - @test grad1 ≈ grad_true - @test grad2 ≈ grad_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa GradientExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Gradient value" begin + @test grad1 ≈ grad_true + @test grad2 ≈ grad_true + end end end test_scen_intact(new_scen, scen) @@ -591,31 +591,34 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_gradient(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, ref_backend, x) else new_scen.ref(x) end - grad1_in = mysimilar(x) - y1, grad1 = value_and_gradient!(f, grad1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_gradient(f, ba, mycopy_random(x)),) + ]) + grad1_in = mysimilar(x) + y1, grad1 = value_and_gradient!(f, grad1_in, ba, x, extras_tup...) - grad2_in = mysimilar(x) - grad2 = gradient!(f, grad2_in, ba, x, extras) + grad2_in = mysimilar(x) + grad2 = gradient!(f, grad2_in, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa GradientExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Gradient value" begin - @test grad1_in ≈ grad_true - @test grad1 ≈ grad_true - @test grad2_in ≈ grad_true - @test grad2 ≈ grad_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa GradientExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Gradient value" begin + @test grad1_in ≈ grad_true + @test grad1 ≈ grad_true + @test grad2_in ≈ grad_true + @test grad2 ≈ grad_true + end end end test_scen_intact(new_scen, scen) @@ -633,27 +636,30 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_jacobian(f, ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f, ref_backend, x) else new_scen.ref(x) end - y1, jac1 = value_and_jacobian(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_jacobian(f, ba, mycopy_random(x)),) + ]) + y1, jac1 = value_and_jacobian(f, ba, x, extras_tup...) - jac2 = jacobian(f, ba, x, extras) + jac2 = jacobian(f, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1 ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1 ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -669,31 +675,34 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_jacobian(f, ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f, ref_backend, x) else new_scen.ref(x) end - jac1_in = mysimilar(jac_true) - y1, jac1 = value_and_jacobian!(f, jac1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_jacobian(f, ba, mycopy_random(x)),) + ]) + jac1_in = mysimilar(jac_true) + y1, jac1 = value_and_jacobian!(f, jac1_in, ba, x, extras_tup...) - jac2_in = mysimilar(jac_true) - jac2 = jacobian!(f, jac2_in, ba, x, extras) + jac2_in = mysimilar(jac_true) + jac2 = jacobian!(f, jac2_in, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1_in ≈ jac_true - @test jac1 ≈ jac_true - @test jac2_in ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1_in ≈ jac_true + @test jac1 ≈ jac_true + @test jac2_in ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -710,30 +719,33 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in = mysimilar(y) - y1, jac1 = value_and_jacobian(f!, y1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in = mysimilar(y) + y1, jac1 = value_and_jacobian(f!, y1_in, ba, x, extras_tup...) - y2_in = mysimilar(y) - jac2 = jacobian(f!, y2_in, ba, x, extras) + y2_in = mysimilar(y) + jac2 = jacobian(f!, y2_in, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1 ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1 ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -750,32 +762,35 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f!, mysimilar(y), ref_backend, x) else new_scen.ref(x) end - y1_in, jac1_in = mysimilar(y), mysimilar(jac_true) - y1, jac1 = value_and_jacobian!(f!, y1_in, jac1_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)),) + ]) + y1_in, jac1_in = mysimilar(y), mysimilar(jac_true) + y1, jac1 = value_and_jacobian!(f!, y1_in, jac1_in, ba, x, extras_tup...) - y2_in, jac2_in = mysimilar(y), mysimilar(jac_true) - jac2 = jacobian!(f!, y2_in, jac2_in, ba, x, extras) + y2_in, jac2_in = mysimilar(y), mysimilar(jac_true) + jac2 = jacobian!(f!, y2_in, jac2_in, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa JacobianExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Jacobian value" begin - @test jac1_in ≈ jac_true - @test jac1 ≈ jac_true - @test jac2_in ≈ jac_true - @test jac2 ≈ jac_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa JacobianExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Jacobian value" begin + @test jac1_in ≈ jac_true + @test jac1 ≈ jac_true + @test jac2_in ≈ jac_true + @test jac2 ≈ jac_true + end end end test_scen_intact(new_scen, scen) @@ -793,7 +808,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_second_derivative(f, ba, mycopy_random(x)) der1_true = if ref_backend isa AbstractADType derivative(f, maybe_inner(ref_backend), x) else @@ -805,22 +819,26 @@ function test_correctness( new_scen.ref(x) end - der21 = second_derivative(f, ba, x, extras) - y2, der12, der22 = value_derivative_and_second_derivative(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_second_derivative(f, ba, mycopy_random(x)),) + ]) + der21 = second_derivative(f, ba, x, extras_tup...) + y2, der12, der22 = value_derivative_and_second_derivative(f, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa SecondDerivativeExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "First derivative value" begin - @test der12 ≈ der1_true - end - @testset "Second derivative value" begin - @test der21 ≈ der2_true - @test der22 ≈ der2_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa SecondDerivativeExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "First derivative value" begin + @test der12 ≈ der1_true + end + @testset "Second derivative value" begin + @test der21 ≈ der2_true + @test der22 ≈ der2_true + end end end test_scen_intact(new_scen, scen) @@ -836,7 +854,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_second_derivative(f, ba, mycopy_random(x)) der1_true = if ref_backend isa AbstractADType derivative(f, maybe_inner(ref_backend), x) else @@ -848,30 +865,34 @@ function test_correctness( new_scen.ref(x) end - der21_in = mysimilar(y) - der21 = second_derivative!(f, der21_in, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_second_derivative(f, ba, mycopy_random(x)),) + ]) + der21_in = mysimilar(y) + der21 = second_derivative!(f, der21_in, ba, x, extras_tup...) - der12_in, der22_in = mysimilar(y), mysimilar(y) - y2, der12, der22 = value_derivative_and_second_derivative!( - f, der12_in, der22_in, ba, x, extras - ) + der12_in, der22_in = mysimilar(y), mysimilar(y) + y2, der12, der22 = value_derivative_and_second_derivative!( + f, der12_in, der22_in, ba, x, extras_tup... + ) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa SecondDerivativeExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "Derivative value" begin - @test der12_in ≈ der1_true - @test der12 ≈ der1_true - end - @testset "Second derivative value" begin - @test der21_in ≈ der2_true - @test der22_in ≈ der2_true - @test der21 ≈ der2_true - @test der22 ≈ der2_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa SecondDerivativeExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Derivative value" begin + @test der12_in ≈ der1_true + @test der12 ≈ der1_true + end + @testset "Second derivative value" begin + @test der21_in ≈ der2_true + @test der22_in ≈ der2_true + @test der21 ≈ der2_true + @test der22 ≈ der2_true + end end end test_scen_intact(new_scen, scen) @@ -895,21 +916,19 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_hvp_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - p1 = hvp(f, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HVPExtras - end - @testset "HVP value" begin - @test p1 ≈ p_true - end + p1 = hvp(f, ba, x, dx, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa HVPExtras + end + @testset "HVP value" begin + @test p1 ≈ p_true end end end @@ -932,23 +951,21 @@ function test_correctness( new_scen.ref(x, dx) end - for (k, extras) in enumerate([ - prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)), - prepare_hvp_same_point(f, ba, x, mycopy_random(dx)), + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), + (prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),), + (prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),), ]) - testset_name = k == 1 ? "Different point" : "Same point" - @testset "$testset_name" begin - p1_in = mysimilar(x) - p1 = hvp!(f, p1_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HVPExtras - end - @testset "HVP value" begin - @test p1_in ≈ p_true - @test p1 ≈ p_true - end + p1_in = mysimilar(x) + p1 = hvp!(f, p1_in, ba, x, dx, extras_tup...) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa HVPExtras + end + @testset "HVP value" begin + @test p1_in ≈ p_true + @test p1 ≈ p_true end end end @@ -967,7 +984,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_hessian(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else @@ -979,22 +995,26 @@ function test_correctness( new_scen.ref(x) end - hess1 = hessian(f, ba, x, extras) - y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras) + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_hessian(f, ba, mycopy_random(x)),) + ]) + hess1 = hessian(f, ba, x, extras_tup...) + y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras_tup...) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HessianExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "Gradient value" begin - @test grad2 ≈ grad_true - end - @testset "Hessian value" begin - @test hess1 ≈ hess_true - @test hess2 ≈ hess_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa HessianExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Gradient value" begin + @test grad2 ≈ grad_true + end + @testset "Hessian value" begin + @test hess1 ≈ hess_true + @test hess2 ≈ hess_true + end end end test_scen_intact(new_scen, scen) @@ -1010,7 +1030,6 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_hessian(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) else @@ -1022,27 +1041,33 @@ function test_correctness( new_scen.ref(x) end - hess1_in = mysimilar(hess_true) - hess1 = hessian!(f, hess1_in, ba, x, extras) - grad2_in, hess2_in = mysimilar(grad_true), mysimilar(hess_true) - y2, grad2, hess2 = value_gradient_and_hessian!(f, grad2_in, hess2_in, ba, x, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HessianExtras - end - @testset "Primal value" begin - @test y2 ≈ y - end - @testset "Gradient value" begin - @test grad2_in ≈ grad_true - @test grad2 ≈ grad_true - end - @testset "Hessian value" begin - @test hess1_in ≈ hess_true - @test hess2_in ≈ hess_true - @test hess1 ≈ hess_true - @test hess2 ≈ hess_true + @testset "$(testset_name(k))" for (k, extras_tup) in enumerate([ + (), (prepare_hessian(f, ba, mycopy_random(x)),) + ]) + hess1_in = mysimilar(hess_true) + hess1 = hessian!(f, hess1_in, ba, x, extras_tup...) + grad2_in, hess2_in = mysimilar(grad_true), mysimilar(hess_true) + y2, grad2, hess2 = value_gradient_and_hessian!( + f, grad2_in, hess2_in, ba, x, extras_tup... + ) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test isempty(extras_tup) || only(extras_tup) isa HessianExtras + end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Gradient value" begin + @test grad2_in ≈ grad_true + @test grad2 ≈ grad_true + end + @testset "Hessian value" begin + @test hess1_in ≈ hess_true + @test hess2_in ≈ hess_true + @test hess1 ≈ hess_true + @test hess2 ≈ hess_true + end end end test_scen_intact(new_scen, scen)