From 310f0cc5d7334e294780acd4f50c9be071ba8151 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 15:14:17 +0100 Subject: [PATCH 1/8] separate rrules for optional arguments --- src/rulesets/LinearAlgebra/dense.jl | 15 +++++++++++++-- src/rulesets/LinearAlgebra/factorization.jl | 8 +++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 30769eca3..aebc8ec8c 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -194,10 +194,21 @@ function frule((_, ΔA), ::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where return Y, ∂Y end +function rrule( + ::typeof(pinv), + x::Union{AbstractVector{T}, LinearAlgebra.AdjOrTransAbsVec{T}}, +) where {T<:Union{Real,Complex}} + y, full_pb = rrule(pinv, x, 0) + function pinv_pullback(Δy) + return pull_pb(Δy)[1:2] + end + return y, pinv_pullback +end + function rrule( ::typeof(pinv), x::AbstractVector{T}, - tol::Real = 0, + tol::Real, ) where {T<:Union{Real,Complex}} y = pinv(x, tol) function pinv_pullback(Δy) @@ -210,7 +221,7 @@ end function rrule( ::typeof(pinv), x::LinearAlgebra.AdjOrTransAbsVec{T}, - tol::Real = 0, + tol::Real, ) where {T<:Union{Real,Complex}} y = pinv(x, tol) function pinv_pullback(Δy) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 109b16f80..d5cf9d7ee 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -428,8 +428,14 @@ end ##### ##### `cholesky` ##### - function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) + y, full_pb = rrule(cholesky, A, :U) + function cholesky_pullback(ΔC::Tangent) + return full_pb(ΔC)[1:2] + end + return C, cholesky_pullback +end +function rrule(::typeof(cholesky), A::Real, uplo::Symbol) C = cholesky(A, uplo) function cholesky_pullback(ΔC::Tangent) return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() From 9295d81cd6cb051c1f989361aaab8529c0374c6c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 15:14:42 +0100 Subject: [PATCH 2/8] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c90fc1443..e73dec2f2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.8.1" +version = "0.8.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 3e5f953ec58a309f242206492478caf4813fc5c7 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 15:19:32 +0100 Subject: [PATCH 3/8] remove optional argument --- src/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index d5cf9d7ee..a86d842bd 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -428,7 +428,7 @@ end ##### ##### `cholesky` ##### -function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) +function rrule(::typeof(cholesky), A::Real) y, full_pb = rrule(cholesky, A, :U) function cholesky_pullback(ΔC::Tangent) return full_pb(ΔC)[1:2] From a8a13bb6c946c2edb421e02a66e65fab8f95e2e5 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 15:20:05 +0100 Subject: [PATCH 4/8] fix typo --- src/rulesets/LinearAlgebra/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index aebc8ec8c..c85e0d988 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -200,7 +200,7 @@ function rrule( ) where {T<:Union{Real,Complex}} y, full_pb = rrule(pinv, x, 0) function pinv_pullback(Δy) - return pull_pb(Δy)[1:2] + return full_pb(Δy)[1:2] end return y, pinv_pullback end From 5926550cd0455f998781c11937534b1d6acb739c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 16:48:27 +0100 Subject: [PATCH 5/8] short function --- src/rulesets/LinearAlgebra/dense.jl | 4 +--- src/rulesets/LinearAlgebra/factorization.jl | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index c85e0d988..3145123a4 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -199,9 +199,7 @@ function rrule( x::Union{AbstractVector{T}, LinearAlgebra.AdjOrTransAbsVec{T}}, ) where {T<:Union{Real,Complex}} y, full_pb = rrule(pinv, x, 0) - function pinv_pullback(Δy) - return full_pb(Δy)[1:2] - end + pinv_pullback(Δy) = return full_pb(Δy)[1:2] return y, pinv_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index a86d842bd..482b7c237 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -430,9 +430,7 @@ end ##### function rrule(::typeof(cholesky), A::Real) y, full_pb = rrule(cholesky, A, :U) - function cholesky_pullback(ΔC::Tangent) - return full_pb(ΔC)[1:2] - end + cholesky_pullback(ΔC::Tangent) = return full_pb(ΔC)[1:2] return C, cholesky_pullback end function rrule(::typeof(cholesky), A::Real, uplo::Symbol) From 90bb64edc9af7204aa47f08ad7c4bcc144c210c5 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 16:54:25 +0100 Subject: [PATCH 6/8] fix cholesky --- src/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 482b7c237..9bb0f64de 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -429,7 +429,7 @@ end ##### `cholesky` ##### function rrule(::typeof(cholesky), A::Real) - y, full_pb = rrule(cholesky, A, :U) + C, full_pb = rrule(cholesky, A, :U) cholesky_pullback(ΔC::Tangent) = return full_pb(ΔC)[1:2] return C, cholesky_pullback end From ac2704fe0fa208a2a784a97e3102eec3415f50ac Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 16:59:32 +0100 Subject: [PATCH 7/8] do not check inference below 1.5 --- test/rulesets/LinearAlgebra/dense.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 2e70338b4..511e01aeb 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -50,7 +50,8 @@ @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) test_frule(pinv, F(randn(T, 3)) ⊢ F(randn(T, 3))) - test_rrule(pinv, F(randn(T, 3))) + check_inferred = VERSION ≥ v"1.5" + test_rrule(pinv, F(randn(T, 3)); check_inferred=check_inferred) # Check types. # TODO: Do we need this still? From ca7005d71af399c6e9d67b3f40e4236a34e98bde Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 4 Jun 2021 18:03:49 +0100 Subject: [PATCH 8/8] do not check_inferred below 1.5 --- test/rulesets/LinearAlgebra/factorization.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 397f1622f..534c45be4 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -357,7 +357,8 @@ end # also we might be missing some overloads for different tangent-types in the rules @testset "cholesky" begin @testset "Real" begin - test_rrule(cholesky, 0.8) + check_inferred = VERSION ≥ v"1.5" + test_rrule(cholesky, 0.8; check_inferred=check_inferred) end @testset "Diagonal{<:Real}" begin D = Diagonal(rand(5) .+ 0.1)