diff --git a/Project.toml b/Project.toml index 3775ae7..23d37c8 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index fddb6b7..73dc86b 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.2" +julia_version = "1.9.3" manifest_format = "2.0" -project_hash = "33634fd84c64daf14892640fe7fd6289d8fcdf5f" +project_hash = "b3e5f4cf27d760c93b2634d045748b8e8637f186" [[deps.ADTypes]] git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a" @@ -75,6 +75,12 @@ git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.4.2" + [[deps.BangBang]] deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" @@ -124,6 +130,36 @@ git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "968c1365e2992824c3e7a794e30907483f8469a9" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "4.4.1" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.5.0+1" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.2.2" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "5248d9c45712e51e27ba9b30eebec65658c6ce29" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.6.0+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "75923dce4275ead3799b238e10178a68c07dbd3b" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "8.9.4+0" + [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e" @@ -142,6 +178,16 @@ git-tree-sha1 = "3ea2d5853a4d132aedd294047bab373333e92027" uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" version = "0.1.4" +[[deps.ChangesOfVariables]] +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.8" +weakdeps = ["InverseFunctions"] + + [deps.ChangesOfVariables.extensions] + ChangesOfVariablesInverseFunctionsExt = "InverseFunctions" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" @@ -178,13 +224,11 @@ version = "1.0.5+0" git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" version = "0.1.2" +weakdeps = ["InverseFunctions"] [deps.CompositionsBase.extensions] CompositionsBaseInverseFunctionsExt = "InverseFunctions" - [deps.CompositionsBase.weakdeps] - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.ConcreteStructs]] git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" uuid = "2569d6c7-a4a2-43d3-a901-331e8e4be471" @@ -318,22 +362,18 @@ weakdeps = ["SparseArrays", "Statistics"] FillArraysStatisticsExt = "Statistics" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "723a8ec75b26fe278256c89c363e370ba733c12e" +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"] +git-tree-sha1 = "3e2c3704c2173ab4b1935362384ca878b53d4c34" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.4" +version = "0.13.17" [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + AMDGPUExt = "AMDGPU" FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] @@ -372,6 +412,12 @@ git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.5" +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.21.4" + [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] git-tree-sha1 = "cb56ccdd481c0dd7f975ad2b3b62d9eda088f7e2" @@ -398,6 +444,12 @@ version = "0.3.1" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.12" + [[deps.IrrationalConstants]] git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" @@ -426,6 +478,11 @@ git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" version = "0.2.4" +[[deps.JuliennedArrays]] +git-tree-sha1 = "4aeebbfcf0615641ec4b0782b73b638eeeabd62e" +uuid = "5cadff95-7770-533d-a838-a1bf817ee6e0" +version = "0.3.0" + [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" @@ -497,31 +554,27 @@ deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" version = "0.3.26" +weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"] [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] -git-tree-sha1 = "a03c77519ab45eb9a34d3cfe2ca223d79c064323" +git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.1" +version = "1.0.2" [[deps.Lux]] deps = ["ADTypes", "Adapt", "ChainRulesCore", "ConcreteStructs", "Functors", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "Markdown", "Optimisers", "PackageExtensionCompat", "Random", "Reexport", "Setfield", "SparseArrays", "Statistics", "TruncatedStacktraces", "WeightInitializers"] -git-tree-sha1 = "2d9956ba484e78c6285f8ffb30d02fc2e7addd36" +git-tree-sha1 = "78fecc38a73321df15161a481864fce75b66ae84" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" -version = "0.5.3" +version = "0.5.4" [deps.Lux.extensions] LuxComponentArraysExt = "ComponentArrays" @@ -656,19 +709,21 @@ version = "1.3.1" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22" +git-tree-sha1 = "72240e3f5ca031937bd536182cb2c031da5f46dd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.4" +version = "0.8.21" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NNlibCUDA]] +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] +git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.7" [[deps.NaNMath]] deps = ["OpenLibm_jll"] @@ -722,9 +777,9 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "af65afa916284e6c7e89f0ab974500cc9235618e" +git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.0" +version = "0.2.20" [[deps.OrderedCollections]] git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" @@ -808,6 +863,18 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.6.1" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + [[deps.RealDot]] deps = ["LinearAlgebra"] git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" @@ -857,6 +924,12 @@ version = "1.15.1" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.0" + [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -882,6 +955,12 @@ git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" version = "0.9.4" +[[deps.SliceMap]] +deps = ["ForwardDiff", "JuliennedArrays", "StaticArrays", "Tracker", "ZygoteRules"] +git-tree-sha1 = "f988004407ccf6c398a87914eafdd8bc9109e533" +uuid = "82cb661a-3f19-5665-9e27-df437c7e54c8" +version = "0.2.7" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -993,8 +1072,8 @@ uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" [[deps.TaylorDiff]] -deps = ["ChainRules", "ChainRulesCore", "ChainRulesOverloadGeneration", "SpecialFunctions", "SymbolicUtils", "Zygote"] -path = ".." +deps = ["ChainRules", "ChainRulesCore", "ChainRulesOverloadGeneration", "SymbolicUtils", "Zygote"] +git-tree-sha1 = "5d7a5fa2e46e068ef8b591c3fcf147a927601181" uuid = "b36ab563-344f-407b-a36a-4f200bebf99c" version = "0.2.1" @@ -1026,6 +1105,18 @@ git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" version = "0.5.23" +[[deps.Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.2.26" + + [deps.Tracker.extensions] + TrackerPDMatsExt = "PDMats" + + [deps.Tracker.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + [[deps.TranscodingStreams]] deps = ["Random", "Test"] git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" @@ -1120,6 +1211,12 @@ git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.3" +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDNN_jll"] +git-tree-sha1 = "ee79f97d07bf875231559f9b3f2649f34fac140b" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.1.0" + [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index dddb7f9..4f8d0e5 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -12,6 +12,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c" TaylorSeries = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" [compat] Zygote = "0.6.55" diff --git a/src/derivative.jl b/src/derivative.jl index 81b66da..c5f2b0b 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -1,38 +1,65 @@ +using SliceMap export derivative """ derivative(f, x::T, order::Int64) derivative(f, x::T, ::Val{N}) -Computes `order`-th derivative of `f` w.r.t. `x`. +Computes `order`-th derivative of `f` w.r.t. scalar `x`. - derivative(f, x::Vector{T}, l::Vector{T}, order::Int64) - derivative(f, x::Vector{T}, l::Vector{T}, ::Val{N}) + derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64) + derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N}) -Computes `order`-th directional derivative of `f` w.r.t. `x` in direction `l`. +Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. + + derivative(f, x::AbstractMatrix{T}, order::Int64) + derivative(f, x::AbstractMatrix{T}, ::Val{N}) + derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64) + derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N}) + +Shorthand notations for multiple calculations. +For a M-by-N matrix, calculate the directional derivative for each column. +For a 1-by-N matrix (row vector), calculate the derivative for each scalar. """ function derivative end -@inline function derivative(f, x::T, order::Int64) where {T <: Number} +# Convenience wrappers for converting orders to value types +# and forward work to core APIs + +@inline function derivative(f, x, order::Int64) derivative(f, x, Val{order + 1}()) end -@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, - order::Int64) where {T <: Number, S <: Number} +@inline function derivative(f, x, l, order::Int64) derivative(f, x, l, Val{order + 1}()) end +# Core APIs + +# Added to help Zygote infer types +make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) + @inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N} t = TaylorScalar{T, N}(x, one(x)) return extract_derivative(f(t), N) end -# Need to rewrite like this to help Zygote infer types -make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) - @inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, vN::Val{N}) where {T <: Number, S <: Number, N} - t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l) + t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) + # equivalent to map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) end + +# shorthand notations for matrices + +@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: Number, N} + size(x)[1] != 1 && @warn "x is not a row vector." + mapcols(u -> derivative(f, u[1], vN), x) +end + +@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S}, + vN::Val{N}) where {T <: Number, S <: Number, N} + mapcols(u -> derivative(f, u, l, vN), x) +end diff --git a/test/Project.toml b/test/Project.toml index 0accb85..72bbeb9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,3 +2,4 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/scalar.jl b/test/scalar.jl index fea2470..444fc68 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -1 +1,6 @@ -@testset "Scalar" begin end + +@testset "Scalar" begin + g(x) = x^3 + @test derivative(g, 1.0, 1) ≈ 3 + @test derivative(g, [2.0 3.0], 1) ≈ [12.0 27.0] +end diff --git a/test/vector.jl b/test/vector.jl index 801f38b..7b62531 100644 --- a/test/vector.jl +++ b/test/vector.jl @@ -2,4 +2,5 @@ @testset "Vector" begin g(x) = x[1] * x[1] + x[2] * x[2] @test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0 + @test derivative(g, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [6.0 10.0] end diff --git a/test/zygote.jl b/test/zygote.jl index caeab9a..a1d2c2b 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -1,16 +1,33 @@ -using Zygote +using Zygote, LinearAlgebra @testset "Zygote for mixed derivative" begin some_number = 0.7 + some_numbers = [0.3 0.4 0.1;] for f in (exp, log, sqrt, sin, asin, sinh, asinh) @test gradient(x -> derivative(f, x, 2), some_number)[1] ≈ derivative(f, some_number, 3) + derivative_result = vec(derivative(f, some_numbers, 3)) + @test Zygote.jacobian(x -> derivative(f, x, 2), some_numbers)[1] ≈ + diagm(derivative_result) end + + some_matrix = [0.7 0.1; 0.4 0.2] + f = x -> sum(tanh.(x), dims = 1) + dfdx1(m, x) = derivative(u -> sum(m(u)), x, [1.0, 0.0], 1) + dfdx2(m, x) = derivative(u -> sum(m(u)), x, [0.0, 1.0], 1) + res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x) + grads = Zygote.gradient(some_matrix) do x + sum(res(f, x)) + end + expected_grads = x -> -2 * sinh(x) / cosh(x)^3 + @test grads[1] ≈ [1 0; 0 2] * expected_grads.(some_matrix) + @test gradient(x -> derivative(x -> x * x, x, 1), 5.0)[1] ≈ 2.0 g(x) = x[1] * x[1] + x[2] * x[2] - @test gradient(x -> derivative(g, x, [1.0, 0.0], 1), [1.0, 2.0])[1] ≈ [2.0, 0.0] + @test gradient(x -> derivative(g, x, [1.0, 0.0], 1), + [1.0, 2.0])[1] ≈ [2.0, 0.0] end @testset "Zygote for parameter optimization" begin