From 91d781acc76d1f7db005b95f824e78add9d89382 Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Wed, 23 Aug 2023 15:48:50 +0200 Subject: [PATCH 1/9] TaylorDiff deals with matrix --- Project.toml | 1 + src/derivative.jl | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/Project.toml b/Project.toml index 3775ae79..23d37c8e 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/derivative.jl b/src/derivative.jl index 81b66da7..bafeba1a 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -1,4 +1,5 @@ +using SliceMap export derivative """ @@ -18,16 +19,31 @@ function derivative end derivative(f, x, Val{order + 1}()) end + +@inline function derivative(f, x::M, order::Int64) where {M <: AbstractMatrix{<:Number}} + mapcols(u -> derivative(f, u[1], order), x) +end + @inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, order::Int64) where {T <: Number, S <: Number} derivative(f, x, l, Val{order + 1}()) end +@inline function derivative(f, x::M, l::V, + order::Int64) where {M <: AbstractMatrix{<:Number}, V <: AbstractVector{<:Number}} + @info "test" + mapcols(u -> derivative(f, u, l, order), x) +end + @inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N} t = TaylorScalar{T, N}(x, one(x)) return extract_derivative(f(t), N) end +@inline function derivative(f, x::M, ::Val{N}) where {M <: AbstractMatrix{<:Number}, N} + mapcols(u -> derivative(f, u[1], N), x) +end + # Need to rewrite like this to help Zygote infer types make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) @@ -36,3 +52,8 @@ make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) end + +@inline function derivative(f, x::M, l::V, + vN::Val{N}) where {M <: AbstractMatrix{<:Number}, V <: AbstractVector{<:Number}, N} + mapcols(u -> derivative(f, u, l, vN), x) +end \ No newline at end of file From 6154c9cd992b1bffafbb81120f6d744d3c87bcb7 Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Sun, 27 Aug 2023 18:47:51 +0200 Subject: [PATCH 2/9] Add tests to matrix differentiation --- src/derivative.jl | 5 ++--- test/Project.toml | 1 + test/vector.jl | 1 + test/zygote.jl | 21 +++++++++++++++++++-- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/derivative.jl b/src/derivative.jl index bafeba1a..2463a657 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -29,9 +29,8 @@ end derivative(f, x, l, Val{order + 1}()) end -@inline function derivative(f, x::M, l::V, - order::Int64) where {M <: AbstractMatrix{<:Number}, V <: AbstractVector{<:Number}} - @info "test" +@inline function derivative(f, x::M, l::L, + order::Int64) where {M <: AbstractMatrix{<:Number}, L <: AbstractVector{<:Number}} mapcols(u -> derivative(f, u, l, order), x) end diff --git a/test/Project.toml b/test/Project.toml index 0accb85a..72bbeb94 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,3 +2,4 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/vector.jl b/test/vector.jl index 801f38b1..f5ff3913 100644 --- a/test/vector.jl +++ b/test/vector.jl @@ -2,4 +2,5 @@ @testset "Vector" begin g(x) = x[1] * x[1] + x[2] * x[2] @test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0 + @test derivative(g, [1.0; 2.0;; 2.0; 3.0], [1.0, 1.0], 1) ≈ [6.0 10.0] end diff --git a/test/zygote.jl b/test/zygote.jl index caeab9a1..1a6fbe36 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -1,16 +1,33 @@ -using Zygote +using Zygote, LinearAlgebra @testset "Zygote for mixed derivative" begin some_number = 0.7 + some_numbers = [0.3 0.4 2.0;] for f in (exp, log, sqrt, sin, asin, sinh, asinh) @test gradient(x -> derivative(f, x, 2), some_number)[1] ≈ derivative(f, some_number, 3) + derivative_result = vec(derivative(f, some_numbers, 3)) + @test Zygote.jacobian(x -> derivative(f, x, 2), some_numbers)[1] ≈ + diagm(derivative_result) end + + some_matrix = [0.7; 0.1;; 0.4; 0.2] + f = x -> sum(tanh.(x), dims = 1) + dfdx1(m, x) = derivative(u -> sum(m(u)), x, [1.0, 0.0], 1) + dfdx2(m, x) = derivative(u -> sum(m(u)), x, [0.0, 1.0], 1) + res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x) + grads = Zygote.gradient(some_matrix) do x + sum(res(f, x)) + end + expected_grads = x -> -2 * sinh(x) / cosh(x)^3 + @test grads[1] ≈ [1 0; 0 2] * expected_grads.(some_matrix) + @test gradient(x -> derivative(x -> x * x, x, 1), 5.0)[1] ≈ 2.0 g(x) = x[1] * x[1] + x[2] * x[2] - @test gradient(x -> derivative(g, x, [1.0, 0.0], 1), [1.0, 2.0])[1] ≈ [2.0, 0.0] + @test gradient(x -> derivative(g, x, [1.0, 0.0], 1), + [1.0, 2.0])[1] ≈ [2.0, 0.0] end @testset "Zygote for parameter optimization" begin From 82df7b431833cc202065182c51d7d1c7214717c3 Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Sun, 27 Aug 2023 18:57:30 +0200 Subject: [PATCH 3/9] Fix test --- test/zygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/zygote.jl b/test/zygote.jl index 1a6fbe36..4b17811a 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -2,7 +2,7 @@ using Zygote, LinearAlgebra @testset "Zygote for mixed derivative" begin some_number = 0.7 - some_numbers = [0.3 0.4 2.0;] + some_numbers = [0.3 0.4 0.1;] for f in (exp, log, sqrt, sin, asin, sinh, asinh) @test gradient(x -> derivative(f, x, 2), some_number)[1] ≈ derivative(f, some_number, 3) From 88d37d549baad82a9b5ebfa5fc245fc65280093b Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Mon, 28 Aug 2023 13:35:47 +0200 Subject: [PATCH 4/9] Fix broken test + formatting --- benchmark/Manifest.toml | 385 ++++++++++++++-------------------------- benchmark/Project.toml | 1 + src/derivative.jl | 2 +- test/zygote.jl | 2 +- 4 files changed, 139 insertions(+), 251 deletions(-) diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index fddb6b7e..88ce710f 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.2" +julia_version = "1.8.4" manifest_format = "2.0" -project_hash = "33634fd84c64daf14892640fe7fd6289d8fcdf5f" +project_hash = "b3e5f4cf27d760c93b2634d045748b8e8637f186" [[deps.ADTypes]] git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a" @@ -10,15 +10,10 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" version = "0.2.1" [[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" [[deps.AbstractTrees]] git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" @@ -30,10 +25,6 @@ deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.6.2" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -50,22 +41,6 @@ git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" version = "7.4.11" - [deps.ArrayInterface.extensions] - ArrayInterfaceBandedMatricesExt = "BandedMatrices" - ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" - ArrayInterfaceCUDAExt = "CUDA" - ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" - ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" - ArrayInterfaceTrackerExt = "Tracker" - - [deps.ArrayInterface.weakdeps] - BandedMatrices = "aae01518-5342-5314-be14-df237901396f" - BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" - StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -75,26 +50,18 @@ git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.4.2" + [[deps.BangBang]] deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" version = "0.3.39" - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -124,6 +91,36 @@ git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "968c1365e2992824c3e7a794e30907483f8469a9" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "4.4.1" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.5.0+1" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.2.2" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "5248d9c45712e51e27ba9b30eebec65658c6ce29" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.6.0+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "75923dce4275ead3799b238e10178a68c07dbd3b" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "8.9.4+0" + [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e" @@ -142,6 +139,12 @@ git-tree-sha1 = "3ea2d5853a4d132aedd294047bab373333e92027" uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" version = "0.1.4" +[[deps.ChangesOfVariables]] +deps = ["InverseFunctions", "LinearAlgebra", "Test"] +git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.8" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" @@ -160,31 +163,21 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" [[deps.Compat]] -deps = ["UUIDs"] +deps = ["Dates", "LinearAlgebra", "UUIDs"] git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" version = "4.9.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" +version = "1.0.1+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" version = "0.1.2" - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - - [deps.CompositionsBase.weakdeps] - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.ConcreteStructs]] git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" uuid = "2569d6c7-a4a2-43d3-a901-331e8e4be471" @@ -202,14 +195,6 @@ git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" version = "1.5.3" - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - [[deps.ContextVariablesX]] deps = ["Compat", "Logging", "UUIDs"] git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" @@ -243,9 +228,7 @@ version = "0.1.2" [[deps.DelimitedFiles]] deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" [[deps.DiffResults]] deps = ["StaticArraysCore"] @@ -307,43 +290,22 @@ version = "0.1.1" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random"] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" version = "1.6.1" -weakdeps = ["SparseArrays", "Statistics"] - - [deps.FillArrays.extensions] - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "723a8ec75b26fe278256c89c363e370ba733c12e" +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"] +git-tree-sha1 = "3e2c3704c2173ab4b1935362384ca878b53d4c34" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.4" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "0.13.17" [[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" [[deps.FunctionWrappers]] git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" @@ -372,6 +334,12 @@ git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.5" +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.21.4" + [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] git-tree-sha1 = "cb56ccdd481c0dd7f975ad2b3b62d9eda088f7e2" @@ -398,6 +366,12 @@ version = "0.3.1" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.12" + [[deps.IrrationalConstants]] git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" @@ -426,18 +400,17 @@ git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" version = "0.2.4" +[[deps.JuliennedArrays]] +git-tree-sha1 = "4aeebbfcf0615641ec4b0782b73b638eeeabd62e" +uuid = "5cadff95-7770-533d-a838-a1bf817ee6e0" +version = "0.3.0" + [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" version = "0.9.8" - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" @@ -489,60 +462,29 @@ version = "1.10.2+0" uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +deps = ["Libdl", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" version = "0.3.26" - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] -git-tree-sha1 = "a03c77519ab45eb9a34d3cfe2ca223d79c064323" +git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.1" +version = "1.0.2" [[deps.Lux]] deps = ["ADTypes", "Adapt", "ChainRulesCore", "ConcreteStructs", "Functors", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "Markdown", "Optimisers", "PackageExtensionCompat", "Random", "Reexport", "Setfield", "SparseArrays", "Statistics", "TruncatedStacktraces", "WeightInitializers"] -git-tree-sha1 = "2d9956ba484e78c6285f8ffb30d02fc2e7addd36" +git-tree-sha1 = "78fecc38a73321df15161a481864fce75b66ae84" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" -version = "0.5.3" - - [deps.Lux.extensions] - LuxComponentArraysExt = "ComponentArrays" - LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] - LuxComponentArraysTrackerExt = ["ComponentArrays", "Tracker"] - LuxComponentArraysZygoteExt = ["ComponentArrays", "Zygote"] - LuxFluxTransformExt = "Flux" - LuxLuxAMDGPUExt = "LuxAMDGPU" - LuxLuxCUDAExt = "LuxCUDA" - LuxTrackerExt = "Tracker" - LuxZygoteExt = "Zygote" - - [deps.Lux.weakdeps] - ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" - FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" - Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" - LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" - LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.5.4" [[deps.LuxCore]] deps = ["DocStringExtensions", "Functors", "Random", "Setfield"] @@ -556,41 +498,12 @@ git-tree-sha1 = "e67d2206f6f05f534dccbed1df2b60e452ce4d0d" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" version = "0.1.7" - [deps.LuxDeviceUtils.extensions] - LuxDeviceUtilsComponentArraysExt = "ComponentArrays" - LuxDeviceUtilsFillArraysExt = "FillArrays" - LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" - LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" - LuxDeviceUtilsMetalExt = "Metal" - LuxDeviceUtilsZygoteExt = "Zygote" - - [deps.LuxDeviceUtils.weakdeps] - ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" - FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" - LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" - LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - [[deps.LuxLib]] deps = ["ChainRulesCore", "KernelAbstractions", "Markdown", "NNlib", "PackageExtensionCompat", "Random", "Reexport", "Statistics"] git-tree-sha1 = "06e1f04441a8835413b48c84c016313c16e1687b" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" version = "0.3.2" - [deps.LuxLib.extensions] - LuxLibForwardDiffExt = "ForwardDiff" - LuxLibLuxCUDAExt = "LuxCUDA" - LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] - LuxLibReverseDiffExt = "ReverseDiff" - LuxLibTrackerExt = "Tracker" - - [deps.LuxLib.weakdeps] - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" @@ -621,7 +534,7 @@ version = "1.1.7" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" +version = "2.28.0+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -640,7 +553,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" +version = "2022.2.1" [[deps.MultivariatePolynomials]] deps = ["ChainRulesCore", "DataStructures", "LinearAlgebra", "MutableArithmetics"] @@ -656,19 +569,15 @@ version = "1.3.1" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22" +git-tree-sha1 = "72240e3f5ca031937bd536182cb2c031da5f46dd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.4" +version = "0.8.21" - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +[[deps.NNlibCUDA]] +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] +git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.7" [[deps.NaNMath]] deps = ["OpenLibm_jll"] @@ -695,7 +604,7 @@ version = "0.2.4" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" +version = "0.3.20+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -722,9 +631,9 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "af65afa916284e6c7e89f0ab974500cc9235618e" +git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.0" +version = "0.2.20" [[deps.OrderedCollections]] git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" @@ -732,10 +641,10 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.6.2" [[deps.PackageExtensionCompat]] +deps = ["Requires", "TOML"] git-tree-sha1 = "f9b1e033c2b1205cf30fd119f4e50881316c1923" uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" version = "1.0.1" -weakdeps = ["Requires", "TOML"] [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] @@ -749,9 +658,9 @@ uuid = "570af359-4316-4cb7-8c74-252c00c2016b" version = "1.1.1" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" +version = "1.8.0" [[deps.PkgBenchmark]] deps = ["BenchmarkTools", "Dates", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Pkg", "Printf", "TerminalLoggers", "UUIDs"] @@ -764,10 +673,6 @@ deps = ["Adapt", "ArrayInterface", "ForwardDiff", "Requires"] git-tree-sha1 = "f739b1b3cc7b9949af3b35089931f2b58c289163" uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" version = "0.4.12" -weakdeps = ["ReverseDiff"] - - [deps.PreallocationTools.extensions] - PreallocationToolsReverseDiffExt = "ReverseDiff" [[deps.PrecompileTools]] deps = ["Preferences"] @@ -808,6 +713,18 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.6.1" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + [[deps.RealDot]] deps = ["LinearAlgebra"] git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" @@ -826,16 +743,6 @@ git-tree-sha1 = "7ed35fb5f831aaf09c2d7c8736d44667a1afdcb0" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" version = "2.38.7" - [deps.RecursiveArrayTools.extensions] - RecursiveArrayToolsMeasurementsExt = "Measurements" - RecursiveArrayToolsTrackerExt = "Tracker" - RecursiveArrayToolsZygoteExt = "Zygote" - - [deps.RecursiveArrayTools.weakdeps] - Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" @@ -857,6 +764,12 @@ version = "1.15.1" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.0" + [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -882,6 +795,12 @@ git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" version = "0.9.4" +[[deps.SliceMap]] +deps = ["ForwardDiff", "JuliennedArrays", "StaticArrays", "Tracker", "ZygoteRules"] +git-tree-sha1 = "f988004407ccf6c398a87914eafdd8bc9109e533" +uuid = "82cb661a-3f19-5665-9e27-df437c7e54c8" +version = "0.2.7" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -892,18 +811,14 @@ uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" version = "1.1.1" [[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.3.1" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" [[deps.SplittablesBase]] deps = ["Setfield", "Test"] @@ -912,14 +827,10 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" uuid = "90137ffa-7385-5640-81b9-e52037218182" version = "1.6.2" -weakdeps = ["Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" @@ -929,7 +840,6 @@ version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -953,11 +863,6 @@ version = "0.6.15" deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" - [[deps.SymbolicIndexingInterface]] deps = ["DocStringExtensions"] git-tree-sha1 = "f8ab052bfcbdb9b48fad2c80c873aa0d0344dfe5" @@ -973,7 +878,7 @@ version = "1.2.0" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" +version = "1.0.0" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -990,11 +895,11 @@ version = "1.10.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" +version = "1.10.1" [[deps.TaylorDiff]] -deps = ["ChainRules", "ChainRulesCore", "ChainRulesOverloadGeneration", "SpecialFunctions", "SymbolicUtils", "Zygote"] -path = ".." +deps = ["ChainRules", "ChainRulesCore", "ChainRulesOverloadGeneration", "SymbolicUtils", "Zygote"] +git-tree-sha1 = "5d7a5fa2e46e068ef8b591c3fcf147a927601181" uuid = "b36ab563-344f-407b-a36a-4f200bebf99c" version = "0.2.1" @@ -1004,12 +909,6 @@ git-tree-sha1 = "50718b4fc1ce20cecf28d85215028c78b4d875c2" uuid = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea" version = "0.15.2" - [deps.TaylorSeries.extensions] - TaylorSeriesIAExt = "IntervalArithmetic" - - [deps.TaylorSeries.weakdeps] - IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" - [[deps.TerminalLoggers]] deps = ["LeftChildRightSiblingTrees", "Logging", "Markdown", "Printf", "ProgressLogging", "UUIDs"] git-tree-sha1 = "f133fab380933d042f6796eda4e130272ba520ca" @@ -1026,6 +925,12 @@ git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" version = "0.5.23" +[[deps.Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.2.26" + [[deps.TranscodingStreams]] deps = ["Random", "Test"] git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" @@ -1038,20 +943,6 @@ git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" version = "0.4.78" - [deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - [[deps.TruncatedStacktraces]] deps = ["InteractiveUtils", "MacroTools", "Preferences"] git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" @@ -1096,7 +987,7 @@ version = "0.1.1" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" +version = "1.2.12+3" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -1104,26 +995,22 @@ git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" version = "0.6.63" - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.3" +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDNN_jll"] +git-tree-sha1 = "ee79f97d07bf875231559f9b3f2649f34fac140b" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.1.0" + [[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" +version = "5.1.1+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/benchmark/Project.toml b/benchmark/Project.toml index dddb7f94..4f8d0e59 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -12,6 +12,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c" TaylorSeries = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" [compat] Zygote = "0.6.55" diff --git a/src/derivative.jl b/src/derivative.jl index 2463a657..7b79e7e6 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -55,4 +55,4 @@ end @inline function derivative(f, x::M, l::V, vN::Val{N}) where {M <: AbstractMatrix{<:Number}, V <: AbstractVector{<:Number}, N} mapcols(u -> derivative(f, u, l, vN), x) -end \ No newline at end of file +end diff --git a/test/zygote.jl b/test/zygote.jl index 4b17811a..a1d2c2bc 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -11,7 +11,7 @@ using Zygote, LinearAlgebra diagm(derivative_result) end - some_matrix = [0.7; 0.1;; 0.4; 0.2] + some_matrix = [0.7 0.1; 0.4 0.2] f = x -> sum(tanh.(x), dims = 1) dfdx1(m, x) = derivative(u -> sum(m(u)), x, [1.0, 0.0], 1) dfdx2(m, x) = derivative(u -> sum(m(u)), x, [0.0, 1.0], 1) From 70822df0792683cf120db5469596a2dc964c90a8 Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Mon, 28 Aug 2023 13:40:38 +0200 Subject: [PATCH 5/9] Fix broken tests --- test/vector.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/vector.jl b/test/vector.jl index f5ff3913..7b625312 100644 --- a/test/vector.jl +++ b/test/vector.jl @@ -2,5 +2,5 @@ @testset "Vector" begin g(x) = x[1] * x[1] + x[2] * x[2] @test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0 - @test derivative(g, [1.0; 2.0;; 2.0; 3.0], [1.0, 1.0], 1) ≈ [6.0 10.0] + @test derivative(g, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [6.0 10.0] end From 50700191497a7ca7f96e55cca442f64a958e9b28 Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Mon, 28 Aug 2023 13:52:45 +0200 Subject: [PATCH 6/9] Clean types in derivative.jl --- src/derivative.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/derivative.jl b/src/derivative.jl index 7b79e7e6..1ed38bae 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -29,8 +29,8 @@ end derivative(f, x, l, Val{order + 1}()) end -@inline function derivative(f, x::M, l::L, - order::Int64) where {M <: AbstractMatrix{<:Number}, L <: AbstractVector{<:Number}} +@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, + order::Int64) where {T <: Number} mapcols(u -> derivative(f, u, l, order), x) end @@ -39,7 +39,7 @@ end return extract_derivative(f(t), N) end -@inline function derivative(f, x::M, ::Val{N}) where {M <: AbstractMatrix{<:Number}, N} +@inline function derivative(f, x::AbstractMatrix{<:Number}, ::Val{N}) where {N} mapcols(u -> derivative(f, u[1], N), x) end @@ -52,7 +52,7 @@ make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t return extract_derivative(f(t), N) end -@inline function derivative(f, x::M, l::V, - vN::Val{N}) where {M <: AbstractMatrix{<:Number}, V <: AbstractVector{<:Number}, N} +@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, + vN::Val{N}) where {T <: Number, N} mapcols(u -> derivative(f, u, l, vN), x) end From fbc1de1421175bcce206e0ff3d3da664fbe8ae64 Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Tue, 29 Aug 2023 07:54:40 +0200 Subject: [PATCH 7/9] Add a warning and increase test coverage --- src/derivative.jl | 5 +++-- test/scalar.jl | 7 ++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/derivative.jl b/src/derivative.jl index 1ed38bae..426c5eb1 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -19,8 +19,8 @@ function derivative end derivative(f, x, Val{order + 1}()) end - -@inline function derivative(f, x::M, order::Int64) where {M <: AbstractMatrix{<:Number}} +@inline function derivative(f, x::AbstractMatrix{<:Number}, order::Int64) + size(x)[1] != 1 && @warn "x is not a row vector." mapcols(u -> derivative(f, u[1], order), x) end @@ -40,6 +40,7 @@ end end @inline function derivative(f, x::AbstractMatrix{<:Number}, ::Val{N}) where {N} + size(x)[1] != 1 && @warn "x is not a row vector." mapcols(u -> derivative(f, u[1], N), x) end diff --git a/test/scalar.jl b/test/scalar.jl index fea2470d..4d790d2a 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -1 +1,6 @@ -@testset "Scalar" begin end + +@testset "Scalar" begin + g(x) = x^3 + @test derivative(g, 1.0, 1) ≈ 3 + @test derivative(g, [2.0 3.0], 1) ≈ [12.0 27.0] +end From d8a164afe7ae9e09ea6cdd9e8a31cc5db38abdae Mon Sep 17 00:00:00 2001 From: Matthieu Barreau Date: Tue, 29 Aug 2023 10:01:38 +0200 Subject: [PATCH 8/9] Formatting + fix in test --- src/derivative.jl | 27 +++++++++++---------------- test/scalar.jl | 2 +- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/derivative.jl b/src/derivative.jl index 426c5eb1..98191570 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -4,42 +4,37 @@ export derivative """ derivative(f, x::T, order::Int64) + derivative(f, x::AbstractMatrix{T}, order::Int64) derivative(f, x::T, ::Val{N}) + derivative(f, x::AbstractMatrix{T}, ::Val{N}) Computes `order`-th derivative of `f` w.r.t. `x`. - derivative(f, x::Vector{T}, l::Vector{T}, order::Int64) - derivative(f, x::Vector{T}, l::Vector{T}, ::Val{N}) + derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64) + derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64) + derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N}) + derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N}) Computes `order`-th directional derivative of `f` w.r.t. `x` in direction `l`. """ function derivative end -@inline function derivative(f, x::T, order::Int64) where {T <: Number} +@inline function derivative(f, x::Union{T, AbstractMatrix{T}}, + order::Int64) where {T <: Number} derivative(f, x, Val{order + 1}()) end -@inline function derivative(f, x::AbstractMatrix{<:Number}, order::Int64) - size(x)[1] != 1 && @warn "x is not a row vector." - mapcols(u -> derivative(f, u[1], order), x) -end - -@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, - order::Int64) where {T <: Number, S <: Number} +@inline function derivative(f, x::Union{AbstractVector{T}, AbstractMatrix{T}}, + l::AbstractVector{S}, order::Int64) where {T <: Number, S <: Number} derivative(f, x, l, Val{order + 1}()) end -@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, - order::Int64) where {T <: Number} - mapcols(u -> derivative(f, u, l, order), x) -end - @inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N} t = TaylorScalar{T, N}(x, one(x)) return extract_derivative(f(t), N) end -@inline function derivative(f, x::AbstractMatrix{<:Number}, ::Val{N}) where {N} +@inline function derivative(f, x::AbstractMatrix{<:Number}, N::Val) size(x)[1] != 1 && @warn "x is not a row vector." mapcols(u -> derivative(f, u[1], N), x) end diff --git a/test/scalar.jl b/test/scalar.jl index 4d790d2a..444fc68e 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -1,5 +1,5 @@ -@testset "Scalar" begin +@testset "Scalar" begin g(x) = x^3 @test derivative(g, 1.0, 1) ≈ 3 @test derivative(g, [2.0 3.0], 1) ≈ [12.0 27.0] From 6ad719f322e78f042b2db03237dbd26986c99cbb Mon Sep 17 00:00:00 2001 From: Songchen Tan Date: Sat, 2 Sep 2023 08:11:25 +0000 Subject: [PATCH 9/9] clean up matrix APIs --- benchmark/Manifest.toml | 256 ++++++++++++++++++++++++++++++++++++---- src/derivative.jl | 51 ++++---- 2 files changed, 264 insertions(+), 43 deletions(-) diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index 88ce710f..73dc86b4 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.8.4" +julia_version = "1.9.3" manifest_format = "2.0" project_hash = "b3e5f4cf27d760c93b2634d045748b8e8637f186" @@ -10,10 +10,15 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" version = "0.2.1" [[deps.AbstractFFTs]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +deps = ["LinearAlgebra"] git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" [[deps.AbstractTrees]] git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" @@ -25,6 +30,10 @@ deps = ["LinearAlgebra", "Requires"] git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.6.2" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -41,6 +50,22 @@ git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" version = "7.4.11" + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -62,6 +87,20 @@ git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" version = "0.3.39" + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -140,10 +179,14 @@ uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" version = "0.1.4" [[deps.ChangesOfVariables]] -deps = ["InverseFunctions", "LinearAlgebra", "Test"] +deps = ["LinearAlgebra", "Test"] git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.8" +weakdeps = ["InverseFunctions"] + + [deps.ChangesOfVariables.extensions] + ChangesOfVariablesInverseFunctionsExt = "InverseFunctions" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -163,20 +206,28 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" [[deps.Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] +deps = ["UUIDs"] git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" version = "4.9.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.1+0" +version = "1.0.5+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" [[deps.ConcreteStructs]] git-tree-sha1 = "f749037478283d372048690eb3b5f92a79432b34" @@ -195,6 +246,14 @@ git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" version = "1.5.3" + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [[deps.ContextVariablesX]] deps = ["Compat", "Logging", "UUIDs"] git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" @@ -228,7 +287,9 @@ version = "0.1.2" [[deps.DelimitedFiles]] deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" [[deps.DiffResults]] deps = ["StaticArraysCore"] @@ -290,10 +351,15 @@ version = "0.1.1" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +deps = ["LinearAlgebra", "Random"] git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" version = "1.6.1" +weakdeps = ["SparseArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" [[deps.Flux]] deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"] @@ -301,11 +367,23 @@ git-tree-sha1 = "3e2c3704c2173ab4b1935362384ca878b53d4c34" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" version = "0.13.17" + [deps.Flux.extensions] + AMDGPUExt = "AMDGPU" + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + [[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" [[deps.FunctionWrappers]] git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" @@ -411,6 +489,12 @@ git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" version = "0.9.8" + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" @@ -462,14 +546,20 @@ version = "1.10.2+0" uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" version = "0.3.26" +weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"] + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -486,6 +576,27 @@ git-tree-sha1 = "78fecc38a73321df15161a481864fce75b66ae84" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" version = "0.5.4" + [deps.Lux.extensions] + LuxComponentArraysExt = "ComponentArrays" + LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] + LuxComponentArraysTrackerExt = ["ComponentArrays", "Tracker"] + LuxComponentArraysZygoteExt = ["ComponentArrays", "Zygote"] + LuxFluxTransformExt = "Flux" + LuxLuxAMDGPUExt = "LuxAMDGPU" + LuxLuxCUDAExt = "LuxCUDA" + LuxTrackerExt = "Tracker" + LuxZygoteExt = "Zygote" + + [deps.Lux.weakdeps] + ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.LuxCore]] deps = ["DocStringExtensions", "Functors", "Random", "Setfield"] git-tree-sha1 = "f2dafe0ddcecf06247b40dbf336acd14e0adce6d" @@ -498,12 +609,41 @@ git-tree-sha1 = "e67d2206f6f05f534dccbed1df2b60e452ce4d0d" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" version = "0.1.7" + [deps.LuxDeviceUtils.extensions] + LuxDeviceUtilsComponentArraysExt = "ComponentArrays" + LuxDeviceUtilsFillArraysExt = "FillArrays" + LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" + LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" + LuxDeviceUtilsMetalExt = "Metal" + LuxDeviceUtilsZygoteExt = "Zygote" + + [deps.LuxDeviceUtils.weakdeps] + ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.LuxLib]] deps = ["ChainRulesCore", "KernelAbstractions", "Markdown", "NNlib", "PackageExtensionCompat", "Random", "Reexport", "Statistics"] git-tree-sha1 = "06e1f04441a8835413b48c84c016313c16e1687b" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" version = "0.3.2" + [deps.LuxLib.extensions] + LuxLibForwardDiffExt = "ForwardDiff" + LuxLibLuxCUDAExt = "LuxCUDA" + LuxLibLuxCUDATrackerExt = ["LuxCUDA", "Tracker"] + LuxLibReverseDiffExt = "ReverseDiff" + LuxLibTrackerExt = "Tracker" + + [deps.LuxLib.weakdeps] + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" @@ -534,7 +674,7 @@ version = "1.1.7" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" +version = "2.28.2+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -553,7 +693,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.2.1" +version = "2022.10.11" [[deps.MultivariatePolynomials]] deps = ["ChainRulesCore", "DataStructures", "LinearAlgebra", "MutableArithmetics"] @@ -573,6 +713,12 @@ git-tree-sha1 = "72240e3f5ca031937bd536182cb2c031da5f46dd" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.8.21" + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + [[deps.NNlibCUDA]] deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" @@ -604,7 +750,7 @@ version = "0.2.4" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.20+0" +version = "0.3.21+4" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -641,10 +787,10 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.6.2" [[deps.PackageExtensionCompat]] -deps = ["Requires", "TOML"] git-tree-sha1 = "f9b1e033c2b1205cf30fd119f4e50881316c1923" uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" version = "1.0.1" +weakdeps = ["Requires", "TOML"] [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] @@ -658,9 +804,9 @@ uuid = "570af359-4316-4cb7-8c74-252c00c2016b" version = "1.1.1" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" +version = "1.9.2" [[deps.PkgBenchmark]] deps = ["BenchmarkTools", "Dates", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Pkg", "Printf", "TerminalLoggers", "UUIDs"] @@ -673,6 +819,10 @@ deps = ["Adapt", "ArrayInterface", "ForwardDiff", "Requires"] git-tree-sha1 = "f739b1b3cc7b9949af3b35089931f2b58c289163" uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" version = "0.4.12" +weakdeps = ["ReverseDiff"] + + [deps.PreallocationTools.extensions] + PreallocationToolsReverseDiffExt = "ReverseDiff" [[deps.PrecompileTools]] deps = ["Preferences"] @@ -743,6 +893,16 @@ git-tree-sha1 = "7ed35fb5f831aaf09c2d7c8736d44667a1afdcb0" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" version = "2.38.7" + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" @@ -811,14 +971,18 @@ uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" version = "1.1.1" [[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] -deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.3.1" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" [[deps.SplittablesBase]] deps = ["Setfield", "Test"] @@ -827,10 +991,14 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +deps = ["LinearAlgebra", "Random", "StaticArraysCore"] git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" uuid = "90137ffa-7385-5640-81b9-e52037218182" version = "1.6.2" +weakdeps = ["Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" @@ -840,6 +1008,7 @@ version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -863,6 +1032,11 @@ version = "0.6.15" deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+6" + [[deps.SymbolicIndexingInterface]] deps = ["DocStringExtensions"] git-tree-sha1 = "f8ab052bfcbdb9b48fad2c80c873aa0d0344dfe5" @@ -878,7 +1052,7 @@ version = "1.2.0" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" +version = "1.0.3" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -895,7 +1069,7 @@ version = "1.10.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.1" +version = "1.10.0" [[deps.TaylorDiff]] deps = ["ChainRules", "ChainRulesCore", "ChainRulesOverloadGeneration", "SymbolicUtils", "Zygote"] @@ -909,6 +1083,12 @@ git-tree-sha1 = "50718b4fc1ce20cecf28d85215028c78b4d875c2" uuid = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea" version = "0.15.2" + [deps.TaylorSeries.extensions] + TaylorSeriesIAExt = "IntervalArithmetic" + + [deps.TaylorSeries.weakdeps] + IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" + [[deps.TerminalLoggers]] deps = ["LeftChildRightSiblingTrees", "Logging", "Markdown", "Printf", "ProgressLogging", "UUIDs"] git-tree-sha1 = "f133fab380933d042f6796eda4e130272ba520ca" @@ -931,6 +1111,12 @@ git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" version = "0.2.26" + [deps.Tracker.extensions] + TrackerPDMatsExt = "PDMats" + + [deps.Tracker.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + [[deps.TranscodingStreams]] deps = ["Random", "Test"] git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" @@ -943,6 +1129,20 @@ git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" version = "0.4.78" + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + [[deps.TruncatedStacktraces]] deps = ["InteractiveUtils", "MacroTools", "Preferences"] git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" @@ -987,7 +1187,7 @@ version = "0.1.1" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+3" +version = "1.2.13+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -995,6 +1195,16 @@ git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" version = "0.6.63" + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" @@ -1008,9 +1218,9 @@ uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" version = "1.1.0" [[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.1.1+0" +version = "5.8.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/src/derivative.jl b/src/derivative.jl index 98191570..c5f2b0b5 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -4,51 +4,62 @@ export derivative """ derivative(f, x::T, order::Int64) - derivative(f, x::AbstractMatrix{T}, order::Int64) derivative(f, x::T, ::Val{N}) - derivative(f, x::AbstractMatrix{T}, ::Val{N}) -Computes `order`-th derivative of `f` w.r.t. `x`. +Computes `order`-th derivative of `f` w.r.t. scalar `x`. derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64) - derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64) derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N}) + +Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. + + derivative(f, x::AbstractMatrix{T}, order::Int64) + derivative(f, x::AbstractMatrix{T}, ::Val{N}) + derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64) derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N}) -Computes `order`-th directional derivative of `f` w.r.t. `x` in direction `l`. +Shorthand notations for multiple calculations. +For a M-by-N matrix, calculate the directional derivative for each column. +For a 1-by-N matrix (row vector), calculate the derivative for each scalar. """ function derivative end -@inline function derivative(f, x::Union{T, AbstractMatrix{T}}, - order::Int64) where {T <: Number} +# Convenience wrappers for converting orders to value types +# and forward work to core APIs + +@inline function derivative(f, x, order::Int64) derivative(f, x, Val{order + 1}()) end -@inline function derivative(f, x::Union{AbstractVector{T}, AbstractMatrix{T}}, - l::AbstractVector{S}, order::Int64) where {T <: Number, S <: Number} +@inline function derivative(f, x, l, order::Int64) derivative(f, x, l, Val{order + 1}()) end +# Core APIs + +# Added to help Zygote infer types +make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) + @inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N} t = TaylorScalar{T, N}(x, one(x)) return extract_derivative(f(t), N) end -@inline function derivative(f, x::AbstractMatrix{<:Number}, N::Val) - size(x)[1] != 1 && @warn "x is not a row vector." - mapcols(u -> derivative(f, u[1], N), x) -end - -# Need to rewrite like this to help Zygote infer types -make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) - @inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, vN::Val{N}) where {T <: Number, S <: Number, N} - t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l) + t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) + # equivalent to map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) end -@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, - vN::Val{N}) where {T <: Number, N} +# shorthand notations for matrices + +@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: Number, N} + size(x)[1] != 1 && @warn "x is not a row vector." + mapcols(u -> derivative(f, u[1], vN), x) +end + +@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S}, + vN::Val{N}) where {T <: Number, S <: Number, N} mapcols(u -> derivative(f, u, l, vN), x) end