From 2ab452ba28cba9cc7b3b235117144974070f9211 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 16 Mar 2025 15:43:45 -0400 Subject: [PATCH 001/126] implement `foreachblock` --- src/tensors/blockiterator.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index 7929b5d19..8f291dd13 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -12,3 +12,29 @@ Base.IteratorSize(::BlockIterator) = Base.HasLength() Base.IteratorEltype(::BlockIterator) = Base.HasEltype() Base.eltype(::Type{<:BlockIterator{T}}) where {T} = Pair{sectortype(T),blocktype(T)} Base.length(iter::BlockIterator) = length(iter.structure) +Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...) + +# TODO: fast-path when structures are the same? +# TODO: do we want f(c, bs...) or f(c, bs)? +# TODO: implement scheduler +# TODO: do we prefer `blocks(t, ts...)` instead or as well? +""" + foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; [scheduler]) + +Apply `f` to each block of `t` and the corresponding blocks of `ts`. +Optionally, `scheduler` can be used to parallelize the computation. +This function is equivalent to the following loop: + +```julia +for (c, b) in blocks(t) + bs = (b, block.(ts, c)...) + f(c, bs) +end +``` +""" +function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler=nothing) + foreach(blocks(t)) do (c, b) + return f(c, (b, block.(ts, c)...)) + end + return nothing +end From eeb0ac79164b57f3f74aa404f97b99e581bd5226 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 16 Mar 2025 15:44:08 -0400 Subject: [PATCH 002/126] Implement `eig_full!` --- Project.toml | 2 ++ src/TensorKit.jl | 3 ++ src/tensors/matrixalgebrakit.jl | 49 +++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+) create mode 100644 src/tensors/matrixalgebrakit.jl diff --git a/Project.toml b/Project.toml index 7e30fd805..3e8099c65 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.14.11" [deps] LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" @@ -30,6 +31,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" +MatrixAlgebraKit = "0.1.1" PackageExtensionCompat = "1" Random = "1" Strided = "2" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 96a579e25..6f89475c6 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -104,6 +104,8 @@ using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend const TO = TensorOperations +using MatrixAlgebraKit: MatrixAlgebraKit as MAK + using LRUCache using TensorKitSectors @@ -212,6 +214,7 @@ include("tensors/treetransformers.jl") include("tensors/indexmanipulations.jl") include("tensors/diagonal.jl") include("tensors/truncation.jl") +include("tensors/matrixalgebrakit.jl") include("tensors/factorizations.jl") include("tensors/braidingtensor.jl") diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl new file mode 100644 index 000000000..3f5430fb0 --- /dev/null +++ b/src/tensors/matrixalgebrakit.jl @@ -0,0 +1,49 @@ +function MAK.copy_input(::typeof(MAK.eig_full), t::AbstractTensorMap) + return copy_oftype(t, factorisation_scalartype(MAK.eig_full!, t)) +end + +function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) + T = complex(scalartype(t)) + return promote_type(ComplexF32, typeof(zero(T) / sqrt(abs2(one(T))))) +end + +function MAK.check_input(::typeof(MAK.eig_full!), t::AbstractTensorMap, (D, V)) + domain(t) == codomain(t) || + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + Tc = complex(scalartype(t)) + + (D isa DiagonalTensorMap && + scalartype(D) == Tc && + fuse(domain(t)) == space(D, 1)) || + throw(ArgumentError("`eig_full!` requires diagonal tensor D with isomorphic domain and complex `scalartype`")) + + V isa AbstractTensorMap && + scalartype(V) == Tc && + space(V) == (codomain(t) ← codomain(D)) || + throw(ArgumentError("`eig_full!` requires square tensor V with isomorphic domain and complex `scalartype`")) + + return nothing +end + +function MAK.initialize_output(::typeof(MAK.eig_full!), t::AbstractTensorMap, + ::MAK.LAPACK_EigAlgorithm) + Tc = complex(scalartype(t)) + V_diag = fuse(domain(t)) + return DiagonalTensorMap{Tc}(undef, V_diag), similar(t, Tc, domain(t) ← V_diag) +end + +function MAK.eig_full!(t::AbstractTensorMap, (D, V), alg::MAK.LAPACK_EigAlgorithm) + MAK.check_input(MAK.eig_full!, t, (D, V)) + foreachblock(t, D, V) do (_, (b, d, v)) + d′, v′ = MAK.eig_full!(b, (d, v), alg) + # deal with the case where the output is not the same as the input + d === d′ || copyto!(d, d′) + v === v′ || copyto!(v, v′) + return nothing + end + return D, V +end + +function MAK.default_eig_algorithm(::TensorMap{<:LinearAlgebra.BlasFloat}; kwargs...) + return MAK.LAPACK_Expert(; kwargs...) +end From d9697049d047cf62cb9e4696f80bed98af40d26e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 16 Mar 2025 15:46:50 -0400 Subject: [PATCH 003/126] Use `eig_full` in `eig` --- src/tensors/factorizations.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 9a6bf5dd1..cdd443ad4 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -222,10 +222,7 @@ matrices. See the corresponding documentation for more information. See also `eigen` and `eigh`. """ -function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(eig, t), p) - return eig!(tcopy; kwargs...) -end +eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...) = MAK.eig_full(t; kwargs...) """ eigh(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple) -> D, V From d0bf282be6de7e862a0e3d4fa582e944aa3058b4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 19 Mar 2025 11:54:20 -0400 Subject: [PATCH 004/126] Fix factorisation scalartype --- src/tensors/matrixalgebrakit.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index 3f5430fb0..80ff4702c 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -4,7 +4,7 @@ end function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) T = complex(scalartype(t)) - return promote_type(ComplexF32, typeof(zero(T) / sqrt(abs2(one(T))))) + return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) end function MAK.check_input(::typeof(MAK.eig_full!), t::AbstractTensorMap, (D, V)) From 9ca844abd2ccfe82e094367fde6a817326c0a8bd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 19 Mar 2025 12:14:32 -0400 Subject: [PATCH 005/126] Add scheduler support --- Project.toml | 9 +++++++++ src/TensorKit.jl | 3 +++ src/tensors/backends.jl | 28 ++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+) create mode 100644 src/tensors/backends.jl diff --git a/Project.toml b/Project.toml index 3e8099c65..4861dea5c 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,11 @@ version = "0.14.11" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" +OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" @@ -32,8 +35,14 @@ FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" MatrixAlgebraKit = "0.1.1" +OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" +<<<<<<< HEAD +======= +ScopedValues = "1.3.0" +SparseArrays = "1" +>>>>>>> 3f9c871 (Add scheduler support) Strided = "2" TensorKitSectors = "0.1.4, 0.2" TensorOperations = "5.1" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 6f89475c6..7813da24a 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -107,6 +107,8 @@ const TO = TensorOperations using MatrixAlgebraKit: MatrixAlgebraKit as MAK using LRUCache +using OhMyThreads +using ScopedValues using TensorKitSectors import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗ @@ -204,6 +206,7 @@ end #------------------------------------- # general definitions include("tensors/abstracttensor.jl") +include("tensors/backends.jl") include("tensors/blockiterator.jl") include("tensors/tensor.jl") include("tensors/adjoint.jl") diff --git a/src/tensors/backends.jl b/src/tensors/backends.jl new file mode 100644 index 000000000..2f945bc42 --- /dev/null +++ b/src/tensors/backends.jl @@ -0,0 +1,28 @@ +# Scheduler implementation +# ------------------------ +function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...) + return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs) + Threads.nthreads() > 1 ? SerialScheduler() : DynamicScheduler() + else + OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...) + end +end + +""" + const blockscheduler = ScopedValue{Scheduler}(SerialScheduler()) + +The default scheduler used when looping over different blocks in the matrix representation of a +tensor. +For controlling this value, see also [`set_blockscheduler`](@ref) and [`with_blockscheduler`](@ref). +""" +const blockscheduler = ScopedValue{Scheduler}(SerialScheduler()) + +""" + with_blockscheduler(f, [scheduler]; kwargs...) + +Run `f` in a scope where the `blockscheduler` is determined by `scheduler' and `kwargs...`. +""" +@inline function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven(); + kwargs...) + @with blockscheduler => select_scheduler(scheduler; kwargs...) f() +end From d9acdb6f63462ca8915e1e7c6c3c9e249c06badc Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 19 Mar 2025 15:59:32 -0400 Subject: [PATCH 006/126] Add BlockAlgorithm --- src/tensors/backends.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/tensors/backends.jl b/src/tensors/backends.jl index 2f945bc42..40cbd1ba5 100644 --- a/src/tensors/backends.jl +++ b/src/tensors/backends.jl @@ -26,3 +26,18 @@ Run `f` in a scope where the `blockscheduler` is determined by `scheduler' and ` kwargs...) @with blockscheduler => select_scheduler(scheduler; kwargs...) f() end + +# TODO: disable for trivial symmetry or small tensors? +default_blockscheduler(t::AbstractTensorMap) = blockscheduler[] + +# MatrixAlgebraKit +# ---------------- +""" + BlockAlgorithm{A,S}(alg, scheduler) + +Generic wrapper for implementing block-wise algorithms. +""" +struct BlockAlgorithm{A,S} <: MatrixAlgebraKit.AbstractAlgorithm + alg::A + scheduler::S +end From 52c193e82aa4eef9583ebf99f54284a699a66a8f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 19 Mar 2025 15:59:58 -0400 Subject: [PATCH 007/126] Add more matrixalgebra methods --- src/TensorKit.jl | 3 +- src/tensors/blockiterator.jl | 2 +- src/tensors/matrixalgebrakit.jl | 207 ++++++++++++++++++++++++++++---- 3 files changed, 187 insertions(+), 25 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 7813da24a..48b5938fd 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -121,7 +121,7 @@ using Base: @boundscheck, @propagate_inbounds, @constprop, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype using Base.Iterators: product, filter -using LinearAlgebra: LinearAlgebra +using LinearAlgebra: LinearAlgebra, BlasFloat using LinearAlgebra: norm, dot, normalize, normalize!, tr, axpy!, axpby!, lmul!, rmul!, mul!, ldiv!, rdiv!, adjoint, adjoint!, transpose, transpose!, @@ -129,6 +129,7 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr, eigen, eigen!, svd, svd!, isposdef, isposdef!, ishermitian, rank, cond, Diagonal, Hermitian +using MatrixAlgebraKit import Base.Meta diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index 8f291dd13..008d9fdf2 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -34,7 +34,7 @@ end """ function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler=nothing) foreach(blocks(t)) do (c, b) - return f(c, (b, block.(ts, c)...)) + return f(c, (b, map(Base.Fix2(block, c), ts)...)) end return nothing end diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index 80ff4702c..36993331d 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -1,49 +1,210 @@ -function MAK.copy_input(::typeof(MAK.eig_full), t::AbstractTensorMap) - return copy_oftype(t, factorisation_scalartype(MAK.eig_full!, t)) +# Generic +# ------- +for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, :svd_full, + :svd_compact, :svd_vals, :svd_trunc) + @eval function MatrixAlgebraKit.copy_input(::typeof($f), + t::AbstractTensorMap{<:BlasFloat}) + T = factorisation_scalartype($f, t) + return copy_oftype(t, T) + end +end + +# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) +# T = scalartype(t) +# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) +# end + +# Singular value decomposition +# ---------------------------- +function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)) + V_cod = fuse(codomain(t)) + V_dom = fuse(domain(t)) + + (U isa AbstractTensorMap && + scalartype(U) == scalartype(t) && + space(U) == (codomain(t) ← V_cod)) || + throw(ArgumentError("`svd_full!` requires unitary tensor U with same `scalartype`")) + (S isa AbstractTensorMap && + scalartype(S) == real(scalartype(t)) && + space(S) == (V_cod ← V_dom)) || + throw(ArgumentError("`svd_full!` requires rectangular tensor S with real `scalartype`")) + (Vᴴ isa AbstractTensorMap && + scalartype(Vᴴ) == scalartype(t) && + space(Vᴴ) == (V_dom ← domain(t))) || + throw(ArgumentError("`svd_full!` requires unitary tensor Vᴴ with same `scalartype`")) + + return nothing +end + +function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorMap, + (U, S, Vᴴ)) + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) + + (U isa AbstractTensorMap && + scalartype(U) == scalartype(t) && + space(U) == (codomain(t) ← V_cod)) || + throw(ArgumentError("`svd_compact!` requires isometric tensor U with same `scalartype`")) + (S isa DiagonalTensorMap && + scalartype(S) == real(scalartype(t)) && + space(S) == (V_cod ← V_dom)) || + throw(ArgumentError("`svd_compact!` requires diagonal tensor S with real `scalartype`")) + (Vᴴ isa AbstractTensorMap && + scalartype(Vᴴ) == scalartype(t) && + space(Vᴴ) == (V_dom ← domain(t))) || + throw(ArgumentError("`svd_compact!` requires isometric tensor Vᴴ with same `scalartype`")) + + return nothing +end + +# TODO: svd_vals + +function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_cod = fuse(codomain(t)) + V_dom = fuse(domain(t)) + U = similar(t, domain(t) ← V_cod) + S = similar(t, real(scalartype(t)), V_cod ← V_dom) + Vᴴ = similar(t, domain(t) ← V_dom) + return U, S, Vᴴ +end + +function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) + U = similar(t, domain(t) ← V_cod) + S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod ← V_dom) + Vᴴ = similar(t, domain(t) ← V_dom) + return U, S, Vᴴ +end + +# TODO: svd_vals + +function MatrixAlgebraKit.svd_full!(t::AbstractTensorMap, (U, S, Vᴴ), + alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(svd_full!, t, (U, S, Vᴴ)) + + foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ) + if isempty(b) # TODO: remove once MatrixAlgebraKit supports empty matrices + one!(length(u) > 0 ? u : vᴴ) + zerovector!(s) + else + u′, s′, vᴴ′ = MatrixAlgebraKit.svd_full!(b, (u, s, vᴴ), alg.alg) + # deal with the case where the output is not the same as the input + u === u′ || copyto!(u, u′) + s === s′ || copyto!(s, s′) + vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′) + end + return nothing + end + + return U, S, Vᴴ +end + +function MatrixAlgebraKit.svd_compact!(t::AbstractTensorMap, (U, S, Vᴴ), + alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(svd_compact!, t, (U, S, Vᴴ)) + + foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ) + u′, s′, vᴴ′ = svd_compact!(b, (u, s, vᴴ), alg.alg) + # deal with the case where the output is not the same as the input + u === u′ || copyto!(u, u′) + s === s′ || copyto!(s, s′) + vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′) + return nothing + end + + return U, S, Vᴴ +end + +function MatrixAlgebraKit.default_svd_algorithm(t::AbstractTensorMap{<:BlasFloat}; + scheduler=default_blockscheduler(t), + kwargs...) + return BlockAlgorithm(LAPACK_DivideAndConquer(; kwargs...), scheduler) end -function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) - T = complex(scalartype(t)) - return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) +# Eigenvalue decomposition +# ------------------------ +function MatrixAlgebraKit.check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)) + domain(t) == codomain(t) || + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + + V_D = fuse(domain(t)) + + (D isa DiagonalTensorMap && + scalartype(D) == real(scalartype(t)) && + V_D == space(D, 1)) || + throw(ArgumentError("`eigh_full!` requires diagonal tensor D with isomorphic domain and real `scalartype`")) + + V isa AbstractTensorMap && + scalartype(V) == scalartype(t) && + space(V) == (codomain(t) ← V_D) || + throw(ArgumentError("`eigh_full!` requires square tensor V with isomorphic domain and equal `scalartype`")) + + return nothing end -function MAK.check_input(::typeof(MAK.eig_full!), t::AbstractTensorMap, (D, V)) +function MatrixAlgebraKit.check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) Tc = complex(scalartype(t)) + V_D = fuse(domain(t)) (D isa DiagonalTensorMap && scalartype(D) == Tc && - fuse(domain(t)) == space(D, 1)) || + V_D == space(D, 1)) || throw(ArgumentError("`eig_full!` requires diagonal tensor D with isomorphic domain and complex `scalartype`")) V isa AbstractTensorMap && scalartype(V) == Tc && - space(V) == (codomain(t) ← codomain(D)) || + space(V) == (codomain(t) ← V_D) || throw(ArgumentError("`eig_full!` requires square tensor V with isomorphic domain and complex `scalartype`")) return nothing end -function MAK.initialize_output(::typeof(MAK.eig_full!), t::AbstractTensorMap, - ::MAK.LAPACK_EigAlgorithm) +function MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_D = fuse(domain(t)) + T = real(scalartype(t)) + D = DiagonalTensorMap{T}(undef, V_D) + V = similar(t, codomain(t) ← V_D) + return D, V +end + +function MatrixAlgebraKit.initialize_output(::typeof(eig_full!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_D = fuse(domain(t)) Tc = complex(scalartype(t)) - V_diag = fuse(domain(t)) - return DiagonalTensorMap{Tc}(undef, V_diag), similar(t, Tc, domain(t) ← V_diag) + D = DiagonalTensorMap{Tc}(undef, V_D) + V = similar(t, Tc, codomain(t) ← V_D) + return D, V end -function MAK.eig_full!(t::AbstractTensorMap, (D, V), alg::MAK.LAPACK_EigAlgorithm) - MAK.check_input(MAK.eig_full!, t, (D, V)) - foreachblock(t, D, V) do (_, (b, d, v)) - d′, v′ = MAK.eig_full!(b, (d, v), alg) - # deal with the case where the output is not the same as the input - d === d′ || copyto!(d, d′) - v === v′ || copyto!(v, v′) - return nothing +for f in (:eigh_full!, :eig_full!) + @eval function MatrixAlgebraKit.$f(t::AbstractTensorMap, (D, V), + alg::BlockAlgorithm) + MatrixAlgebraKit.check_input($f, t, (D, V)) + + foreachblock(t, D, V; alg.scheduler) do _, (b, d, v) + d′, v′ = $f(b, (d, v), alg.alg) + # deal with the case where the output is not the same as the input + d === d′ || copyto!(d, d′) + v === v′ || copyto!(v, v′) + return nothing + end + + return D, V end - return D, V end -function MAK.default_eig_algorithm(::TensorMap{<:LinearAlgebra.BlasFloat}; kwargs...) - return MAK.LAPACK_Expert(; kwargs...) +function MatrixAlgebraKit.default_eig_algorithm(t::AbstractTensorMap{<:BlasFloat}; + scheduler=default_blockscheduler(t), + kwargs...) + return BlockAlgorithm(LAPACK_Expert(; kwargs...), scheduler) +end +function MatrixAlgebraKit.default_eigh_algorithm(t::AbstractTensorMap{<:BlasFloat}; + scheduler=default_blockscheduler(t), + kwargs...) + return BlockAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...), + scheduler) end From 4e2b0cfc06bb61f650597bf71be9dcb21e0bd190 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 20 Mar 2025 09:12:38 -0400 Subject: [PATCH 008/126] Start switching more factorizations over --- src/tensors/factorizations.jl | 102 +++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 45 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index cdd443ad4..a421ff57b 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -222,7 +222,10 @@ matrices. See the corresponding documentation for more information. See also `eigen` and `eigh`. """ -eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...) = MAK.eig_full(t; kwargs...) +function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + tcopy = permutedcopy_oftype(t, factorisation_scalartype(eig, t), p) + return eig!(tcopy; kwargs...) +end """ eigh(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple) -> D, V @@ -528,6 +531,12 @@ function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) end # implementation dispatches on algorithm +function _tsvd!(t::TensorMap{<:BlasFloat}, alg::Union{SVD,SDD}, + ::NoTruncation, p::Real=2) + scheduler = default_blockscheduler(t) + svd_alg = alg isa SDD ? LAPACK_DivideAndConquer() : LAPACK_QRIteration() + return MatrixAlgebraKit.svd_compact!(t; alg=BlockAlgorithm(svd_alg, scheduler)) +end function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2) # early return @@ -614,50 +623,53 @@ function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwarg for (c, b) in blocks(t)) end -function eigh!(t::TensorMap{<:RealOrComplexFloat}) - InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!) - domain(t) == codomain(t) || - throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same")) - - T = scalartype(t) - I = sectortype(t) - S = spacetype(t) - dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t)) - W = S(dims) - - Tr = real(T) - A = similarstoragetype(t, Tr) - D = DiagonalTensorMap{Tr,S,A}(undef, W) - V = similar(t, domain(t) ← W) - for (c, b) in blocks(t) - values, vectors = MatrixAlgebra.eigh!(b) - copy!(block(D, c), Diagonal(values)) - copy!(block(V, c), vectors) - end - return D, V -end - -function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...) - domain(t) == codomain(t) || - throw(SpaceMismatch("`eig!` requires domain and codomain to be the same")) - - T = scalartype(t) - I = sectortype(t) - S = spacetype(t) - dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t)) - W = S(dims) - - Tc = complex(T) - A = similarstoragetype(t, Tc) - D = DiagonalTensorMap{Tc,S,A}(undef, W) - V = similar(t, Tc, domain(t) ← W) - for (c, b) in blocks(t) - values, vectors = MatrixAlgebra.eig!(b; kwargs...) - copy!(block(D, c), Diagonal(values)) - copy!(block(V, c), vectors) - end - return D, V -end +eigh!(t::TensorMap{<:RealOrComplexFloat}) = eigh_full!(t) +eig!(t::TensorMap{<:RealOrComplexFloat}) = eig_full!(t) + +# function eigh!(t::TensorMap{<:RealOrComplexFloat}) +# InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!) +# domain(t) == codomain(t) || +# throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same")) + +# T = scalartype(t) +# I = sectortype(t) +# S = spacetype(t) +# dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t)) +# W = S(dims) + +# Tr = real(T) +# A = similarstoragetype(t, Tr) +# D = DiagonalTensorMap{Tr,S,A}(undef, W) +# V = similar(t, domain(t) ← W) +# for (c, b) in blocks(t) +# values, vectors = MatrixAlgebra.eigh!(b) +# copy!(block(D, c), Diagonal(values)) +# copy!(block(V, c), vectors) +# end +# return D, V +# end + +# function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...) +# domain(t) == codomain(t) || +# throw(SpaceMismatch("`eig!` requires domain and codomain to be the same")) + +# T = scalartype(t) +# I = sectortype(t) +# S = spacetype(t) +# dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t)) +# W = S(dims) + +# Tc = complex(T) +# A = similarstoragetype(t, Tc) +# D = DiagonalTensorMap{Tc,S,A}(undef, W) +# V = similar(t, Tc, domain(t) ← W) +# for (c, b) in blocks(t) +# values, vectors = MatrixAlgebra.eig!(b; kwargs...) +# copy!(block(D, c), Diagonal(values)) +# copy!(block(V, c), vectors) +# end +# return D, V +# end #--------------------------------------------------# # Checks for hermiticity and positive definiteness # From 774b11ec0a1c535fb7b3ea887eed1a45a0c0d478 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 20 Mar 2025 09:32:07 -0400 Subject: [PATCH 009/126] Improve `svd` error messages --- src/tensors/matrixalgebrakit.jl | 69 ++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index 36993331d..9f563ca8f 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -9,6 +9,18 @@ for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, end end +# TODO: move to MatrixAlgebraKit? +macro check_eltype(x, y, f=:identity, g=:eltype) + msg = "unexpected scalar type: " + msg *= string(g) * "(" * string(x) * ") != " + if f == :identity + msg *= string(g) * "(" * string(y) * ")" + else + msg *= string(f) * "(" * string(y) * ")" + end + return :($g($x) == $f($g($y)) || throw(ArgumentError($msg))) +end + # function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) # T = scalartype(t) # return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) @@ -16,42 +28,43 @@ end # Singular value decomposition # ---------------------------- -function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)) +const T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} + +function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, + (U, S, Vᴴ)::T_USVᴴ) + # scalartype checks + @check_eltype U t + @check_eltype S t real + @check_eltype Vᴴ t + + # space checks V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) - - (U isa AbstractTensorMap && - scalartype(U) == scalartype(t) && - space(U) == (codomain(t) ← V_cod)) || - throw(ArgumentError("`svd_full!` requires unitary tensor U with same `scalartype`")) - (S isa AbstractTensorMap && - scalartype(S) == real(scalartype(t)) && - space(S) == (V_cod ← V_dom)) || - throw(ArgumentError("`svd_full!` requires rectangular tensor S with real `scalartype`")) - (Vᴴ isa AbstractTensorMap && - scalartype(Vᴴ) == scalartype(t) && - space(Vᴴ) == (V_dom ← domain(t))) || - throw(ArgumentError("`svd_full!` requires unitary tensor Vᴴ with same `scalartype`")) + space(U) == (codomain(t) ← V_cod) || + throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← fuse(domain(t)))`")) + space(S) == (V_cod ← V_dom) || + throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(S) == (fuse(codomain(t)) ← fuse(domain(t))`")) + space(Vᴴ) == (V_dom ← domain(t)) || + throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (fuse(domain(t)) ← domain(t))`")) return nothing end function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorMap, - (U, S, Vᴴ)) - V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) + (U, S, Vᴴ)::T_USVᴴ) + # scalartype checks + @check_eltype U t + @check_eltype S t real + @check_eltype Vᴴ t - (U isa AbstractTensorMap && - scalartype(U) == scalartype(t) && - space(U) == (codomain(t) ← V_cod)) || - throw(ArgumentError("`svd_compact!` requires isometric tensor U with same `scalartype`")) - (S isa DiagonalTensorMap && - scalartype(S) == real(scalartype(t)) && - space(S) == (V_cod ← V_dom)) || - throw(ArgumentError("`svd_compact!` requires diagonal tensor S with real `scalartype`")) - (Vᴴ isa AbstractTensorMap && - scalartype(Vᴴ) == scalartype(t) && - space(Vᴴ) == (V_dom ← domain(t))) || - throw(ArgumentError("`svd_compact!` requires isometric tensor Vᴴ with same `scalartype`")) + # space checks + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) + space(U) == (codomain(t) ← V_cod) || + throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← infimum(fuse(domain(t)), fuse(codomain(t)))`")) + space(S) == (V_cod ← V_dom) || + throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires diagonal `S` with `domain(S) == (infimum(fuse(codomain(t)), fuse(domain(t)))`")) + space(Vᴴ) == (V_dom ← domain(t)) || + throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(domain(t)), fuse(codomain(t))) ← domain(t))`")) return nothing end From 1b47a8a5af08d6f0da4c47ce2cf60ba2c24db910 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 20 Mar 2025 09:37:59 -0400 Subject: [PATCH 010/126] more error msg improvements --- src/tensors/matrixalgebrakit.jl | 52 +++++++++++++++++---------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index 9f563ca8f..7858e6c8f 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -28,10 +28,11 @@ end # Singular value decomposition # ---------------------------- -const T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} +const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} +const _T_USVᴴ_diag = Tuple{<:AbstractTensorMap,<:DiagonalTensorMap,<:AbstractTensorMap} function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, - (U, S, Vᴴ)::T_USVᴴ) + (U, S, Vᴴ)::_T_USVᴴ) # scalartype checks @check_eltype U t @check_eltype S t real @@ -51,7 +52,7 @@ function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, end function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorMap, - (U, S, Vᴴ)::T_USVᴴ) + (U, S, Vᴴ)::_T_USVᴴ_diag) # scalartype checks @check_eltype U t @check_eltype S t real @@ -137,40 +138,41 @@ end # Eigenvalue decomposition # ------------------------ -function MatrixAlgebraKit.check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)) +const _T_DV = Tuple{<:DiagonalTensorMap,<:AbstractTensorMap} +function MatrixAlgebraKit.check_input(::typeof(eigh_full!), t::AbstractTensorMap, + (D, V)::_T_DV) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - V_D = fuse(domain(t)) - - (D isa DiagonalTensorMap && - scalartype(D) == real(scalartype(t)) && - V_D == space(D, 1)) || - throw(ArgumentError("`eigh_full!` requires diagonal tensor D with isomorphic domain and real `scalartype`")) + # scalartype checks + @check_eltype D t real + @check_eltype V t - V isa AbstractTensorMap && - scalartype(V) == scalartype(t) && - space(V) == (codomain(t) ← V_D) || - throw(ArgumentError("`eigh_full!` requires square tensor V with isomorphic domain and equal `scalartype`")) + # space checks + V_D = fuse(domain(t)) + V_D == space(D, 1) || + throw(SpaceMismatch("`eigh_full!(t, (D, V))` requires diagonal `D` with `domain(D) == fuse(domain(t))`")) + space(V) == (codomain(t) ← V_D) || + throw(SpaceMismatch("`eigh_full!(t, (D, V))` requires `space(V) == (codomain(t) ← fuse(domain(t)))`")) return nothing end -function MatrixAlgebraKit.check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)) +function MatrixAlgebraKit.check_input(::typeof(eig_full!), t::AbstractTensorMap, + (D, V)::_T_DV) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - Tc = complex(scalartype(t)) - V_D = fuse(domain(t)) - (D isa DiagonalTensorMap && - scalartype(D) == Tc && - V_D == space(D, 1)) || - throw(ArgumentError("`eig_full!` requires diagonal tensor D with isomorphic domain and complex `scalartype`")) + # scalartype checks + @check_eltype D t complex + @check_eltype V t complex - V isa AbstractTensorMap && - scalartype(V) == Tc && - space(V) == (codomain(t) ← V_D) || - throw(ArgumentError("`eig_full!` requires square tensor V with isomorphic domain and complex `scalartype`")) + # space checks + V_D = fuse(domain(t)) + V_D == space(D, 1) || + throw(SpaceMismatch("`eig_full!(t, (D, V))` requires diagonal `D` with `domain(D) == fuse(domain(t))`")) + space(V) == (codomain(t) ← V_D) || + throw(SpaceMismatch("`eig_full!(t, (D, V))` requires `space(V) == (codomain(t) ← fuse(domain(t)))`")) return nothing end From e46e0f31a2d00966d74f63b41485102087cb20e7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 20 Mar 2025 11:42:26 -0400 Subject: [PATCH 011/126] Properly escape macro hygiene --- src/tensors/matrixalgebrakit.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index 7858e6c8f..f3fbd8039 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -18,7 +18,7 @@ macro check_eltype(x, y, f=:identity, g=:eltype) else msg *= string(f) * "(" * string(y) * ")" end - return :($g($x) == $f($g($y)) || throw(ArgumentError($msg))) + return esc(:($g($x) == $f($g($y)) || throw(ArgumentError($msg)))) end # function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) From 0f7718bd0ed21c942775cb8ca22d3ab3efc9d383 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 20 Mar 2025 11:44:58 -0400 Subject: [PATCH 012/126] Fix SVD spaces --- src/tensors/matrixalgebrakit.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index f3fbd8039..3a870bae2 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -78,16 +78,16 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTens V_dom = fuse(domain(t)) U = similar(t, domain(t) ← V_cod) S = similar(t, real(scalartype(t)), V_cod ← V_dom) - Vᴴ = similar(t, domain(t) ← V_dom) + Vᴴ = similar(t, V_dom ← domain(t)) return U, S, Vᴴ end function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, ::MatrixAlgebraKit.AbstractAlgorithm) V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) - U = similar(t, domain(t) ← V_cod) - S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod ← V_dom) - Vᴴ = similar(t, domain(t) ← V_dom) + U = similar(t, codomain(t) ← V_cod) + S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) + Vᴴ = similar(t, V_dom ← domain(t)) return U, S, Vᴴ end From 906dba9f2e4df4fa09177d696f7d924a4946689e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 20 Mar 2025 12:44:29 -0400 Subject: [PATCH 013/126] Also return `truncerr` --- src/tensors/factorizations.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index a421ff57b..10adbd623 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -535,7 +535,8 @@ function _tsvd!(t::TensorMap{<:BlasFloat}, alg::Union{SVD,SDD}, ::NoTruncation, p::Real=2) scheduler = default_blockscheduler(t) svd_alg = alg isa SDD ? LAPACK_DivideAndConquer() : LAPACK_QRIteration() - return MatrixAlgebraKit.svd_compact!(t; alg=BlockAlgorithm(svd_alg, scheduler)) + return MatrixAlgebraKit.svd_compact!(t; alg=BlockAlgorithm(svd_alg, scheduler))..., + zero(real(scalartype(t))) end function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2) From 85b5e390fc195e30bd2c0ebefe5abcb88ba715a0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 2 Apr 2025 14:02:00 -0400 Subject: [PATCH 014/126] Add `setdiff` for ElementarySpace --- src/spaces/cartesianspace.jl | 5 +++++ src/spaces/complexspace.jl | 6 ++++++ src/spaces/gradedspace.jl | 6 ++++++ src/spaces/vectorspaces.jl | 8 ++++++++ 4 files changed, 25 insertions(+) diff --git a/src/spaces/cartesianspace.jl b/src/spaces/cartesianspace.jl index f29ed263d..fd38e0c1e 100644 --- a/src/spaces/cartesianspace.jl +++ b/src/spaces/cartesianspace.jl @@ -56,4 +56,9 @@ flip(V::CartesianSpace) = V infimum(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(min(V₁.d, V₂.d)) supremum(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(max(V₁.d, V₂.d)) +function Base.setdiff(V::CartesianSpace, W::CartesianSpace) + V ≿ W || throw(ArgumentError("$(W) is not a subspace of $(V)")) + return CartesianSpace(dim(V) - dim(W)) +end + Base.show(io::IO, V::CartesianSpace) = print(io, "ℝ^$(V.d)") diff --git a/src/spaces/complexspace.jl b/src/spaces/complexspace.jl index ff05888b8..51a3056e9 100644 --- a/src/spaces/complexspace.jl +++ b/src/spaces/complexspace.jl @@ -69,4 +69,10 @@ function supremum(V₁::ComplexSpace, V₂::ComplexSpace) throw(SpaceMismatch("Supremum of space and dual space does not exist")) end +function Base.setdiff(V::ComplexSpace, W::ComplexSpace) + (V ≿ W && isdual(V) == isdual(W)) || + throw(ArgumentError("$(W) is not a subspace of $(V)")) + return ComplexSpace(dim(V) - dim(W), isdual(V)) +end + Base.show(io::IO, V::ComplexSpace) = print(io, isdual(V) ? "(ℂ^$(V.d))'" : "ℂ^$(V.d)") diff --git a/src/spaces/gradedspace.jl b/src/spaces/gradedspace.jl index 63903a1ac..e4a016602 100644 --- a/src/spaces/gradedspace.jl +++ b/src/spaces/gradedspace.jl @@ -183,6 +183,12 @@ function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector} for c in union(sectors(V₁), sectors(V₂)); dual=Visdual) end +function Base.setdiff(V::GradedSpace{I}, W::GradedSpace{I}) where {I<:Sector} + V ≿ W && isdual(V) == isdual(W) || + throw(SpaceMismatch("$(W) is not a subspace of $(V)")) + return typeof(V)(c => dim(V, c) - dim(W, c) for c in sectors(V)) +end + function Base.show(io::IO, V::GradedSpace{I}) where {I<:Sector} print(io, type_repr(typeof(V)), "(") separator = "" diff --git a/src/spaces/vectorspaces.jl b/src/spaces/vectorspaces.jl index 844da081f..3b6e6f08e 100644 --- a/src/spaces/vectorspaces.jl +++ b/src/spaces/vectorspaces.jl @@ -405,3 +405,11 @@ have the same value. function supremum(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} return supremum(supremum(V₁, V₂), V₃...) end + +""" + setdiff(V::ElementarySpace, W::ElementarySpace) + +Return the set difference of two elementary spaces, i.e. an instance `X::ElementarySpace` +such that `V = W ⊕ X`. +""" +Base.setdiff(V₁::S, V₂::S) where {S<:ElementarySpace} From 784af861234f8a061a5764b629133aab3d07548e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 2 Apr 2025 14:06:55 -0400 Subject: [PATCH 015/126] Add `qr_` implementations --- src/tensors/matrixalgebrakit.jl | 393 ++++++++++++++++++++++++++++++++ test/factorizations.jl | 218 ++++++++++++++++++ test/paul.jl | 65 ++++++ 3 files changed, 676 insertions(+) create mode 100644 test/factorizations.jl create mode 100644 test/paul.jl diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index 3a870bae2..dd45d6203 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -223,3 +223,396 @@ function MatrixAlgebraKit.default_eigh_algorithm(t::AbstractTensorMap{<:BlasFloa return BlockAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...), scheduler) end + +# QR decomposition +# ---------------- +function MatrixAlgebraKit.check_input(::typeof(qr_full!), t::AbstractTensorMap, + (Q, + R)::Tuple{<:AbstractTensorMap,<:AbstractTensorMap}) + # scalartype checks + @check_eltype Q t + @check_eltype R t + + # space checks + V_Q = fuse(codomain(t)) + space(Q) == (codomain(t) ← V_Q) || + throw(SpaceMismatch("`qr_full!(t, (Q, R))` requires `space(Q) == (codomain(t) ← fuse(codomain(t)))`")) + space(R) == (V_Q ← domain(t)) || + throw(SpaceMismatch("`qr_full!(t, (Q, R))` requires `space(R) == (fuse(codomain(t)) ← domain(t)`")) + + return nothing +end + +function MatrixAlgebraKit.check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)) + # scalartype checks + @check_eltype Q t + @check_eltype R t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + space(Q) == (codomain(t) ← V_Q) || + throw(SpaceMismatch("`qr_compact!(t, (Q, R))` requires `space(Q) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`")) + space(R) == (V_Q ← domain(t)) || + throw(SpaceMismatch("`qr_compact!(t, (Q, R))` requires `space(R) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`")) + + return nothing +end + +function MatrixAlgebraKit.check_input(::typeof(qr_null!), t::AbstractTensorMap, + N::AbstractTensorMap) + # scalartype checks + @check_eltype N t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = setdiff(fuse(codomain(t)), V_Q) + space(N) == (codomain(t) ← V_N) || + throw(SpaceMismatch("`qr_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) + + return nothing +end + +function MatrixAlgebraKit.initialize_output(::typeof(qr_full!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_Q = fuse(codomain(t)) + Q = similar(t, codomain(t) ← V_Q) + R = similar(t, V_Q ← domain(t)) + return Q, R +end + +function MatrixAlgebraKit.initialize_output(::typeof(qr_compact!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + Q = similar(t, codomain(t) ← V_Q) + R = similar(t, V_Q ← domain(t)) + return Q, R +end + +function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = setdiff(fuse(codomain(t)), V_Q) + N = similar(t, codomain(t) ← V_N) + return N +end + +function MatrixAlgebraKit.qr_full!(t::AbstractTensorMap, (Q, R), + alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(qr_full!, t, (Q, R)) + + foreachblock(t, Q, R; alg.scheduler) do _, (b, q, r) + q′, r′ = qr_full!(b, (q, r), alg.alg) + # deal with the case where the output is not the same as the input + q === q′ || copyto!(q, q′) + r === r′ || copyto!(r, r′) + return nothing + end + + return Q, R +end + +function MatrixAlgebraKit.qr_compact!(t::AbstractTensorMap, (Q, R), + alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(qr_compact!, t, (Q, R)) + + foreachblock(t, Q, R; alg.scheduler) do _, (b, q, r) + q′, r′ = qr_compact!(b, (q, r), alg.alg) + # deal with the case where the output is not the same as the input + q === q′ || copyto!(q, q′) + r === r′ || copyto!(r, r′) + return nothing + end + + return Q, R +end + +function MatrixAlgebraKit.qr_null!(t::AbstractTensorMap, N, alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(qr_null!, t, N) + + foreachblock(t, N; alg.scheduler) do _, (b, n) + n′ = qr_null!(b, n, alg.alg) + # deal with the case where the output is not the same as the input + n === n′ || copyto!(n, n′) + return nothing + end + + return N +end + +function MatrixAlgebraKit.default_qr_algorithm(t::AbstractTensorMap{<:BlasFloat}; + scheduler=default_blockscheduler(t), + kwargs...) + return BlockAlgorithm(LAPACK_HouseholderQR(; kwargs...), scheduler) +end + +# LQ decomposition +# ---------------- +function MatrixAlgebraKit.check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)) + # scalartype checks + @check_eltype L t + @check_eltype Q t + + # space checks + V_Q = fuse(domain(t)) + space(L) == (codomain(t) ← V_Q) || + throw(SpaceMismatch("`lq_full!(t, (L, Q))` requires `space(L) == (codomain(t) ← fuse(domain(t)))`")) + space(Q) == (V_Q ← domain(t)) || + throw(SpaceMismatch("`lq_full!(t, (L, Q))` requires `space(Q) == (fuse(domain(t)) ← domain(t))`")) + + return nothing +end + +function MatrixAlgebraKit.check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)) + # scalartype checks + @check_eltype L t + @check_eltype Q t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + space(L) == (codomain(t) ← V_Q) || + throw(SpaceMismatch("`lq_compact!(t, (L, Q))` requires `space(L) == infimum(fuse(codomain(t)), fuse(domain(t)))`")) + space(Q) == (V_Q ← domain(t)) || + throw(SpaceMismatch("`lq_compact!(t, (L, Q))` requires `space(Q) == infimum(fuse(codomain(t)), fuse(domain(t)))`")) + + return nothing +end + +function MatrixAlgebraKit.check_input(::typeof(lq_null!), t::AbstractTensorMap, N) + # scalartype checks + @check_eltype N t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = setdiff(fuse(domain(t)), V_Q) + space(N) == (V_N ← domain(t)) || + throw(SpaceMismatch("`lq_null!(t, N)` requires `space(N) == setdiff(fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`")) + + return nothing +end + +function MatrixAlgebraKit.initialize_output(::typeof(lq_full!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_Q = fuse(domain(t)) + L = similar(t, codomain(t) ← V_Q) + Q = similar(t, V_Q ← domain(t)) + return L, Q +end + +function MatrixAlgebraKit.initialize_output(::typeof(lq_compact!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + L = similar(t, codomain(t) ← V_Q) + Q = similar(t, V_Q ← domain(t)) + return L, Q +end + +function MatrixAlgebraKit.initialize_output(::typeof(lq_null!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = setdiff(fuse(domain(t)), V_Q) + N = similar(t, V_N ← domain(t)) + return N +end + +function MatrixAlgebraKit.lq_full!(t::AbstractTensorMap, (L, Q), + alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(lq_full!, t, (L, Q)) + + foreachblock(t, L, Q; alg.scheduler) do _, (b, l, q) + l′, q′ = lq_full!(b, (l, q), alg.alg) + # deal with the case where the output is not the same as the input + l === l′ || copyto!(l, l′) + q === q′ || copyto!(q, q′) + return nothing + end + + return L, Q +end + +function MatrixAlgebraKit.lq_compact!(t::AbstractTensorMap, (L, Q), + alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(lq_compact!, t, (L, Q)) + + foreachblock(t, L, Q; alg.scheduler) do _, (b, l, q) + l′, q′ = lq_compact!(b, (l, q), alg.alg) + # deal with the case where the output is not the same as the input + l === l′ || copyto!(l, l′) + q === q′ || copyto!(q, q′) + return nothing + end + + return L, Q +end + +function MatrixAlgebraKit.lq_null!(t::AbstractTensorMap, N, alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(lq_null!, t, N) + + foreachblock(t, N; alg.scheduler) do _, (b, n) + n′ = lq_null!(b, n, alg.alg) + # deal with the case where the output is not the same as the input + n === n′ || copyto!(n, n′) + return nothing + end + + return N +end + +# Polar decomposition +# ------------------- +using MatrixAlgebraKit: PolarViaSVD + +function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P)) + codomain(t) ≿ domain(t) || + throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) + + # scalartype checks + @check_eltype W t + @check_eltype P t + + # space checks + space(W) == (codomain(t) ← fuse(domain(t))) || + throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`")) + space(P) == (fuse(domain(t)) ← domain(t)) || + throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`")) + + return nothing +end + +# TODO: do we really not want to fuse the spaces? +function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap) + W = similar(t, codomain(t) ← fuse(domain(t))) + P = similar(t, fuse(domain(t)) ← domain(t)) + return W, P +end + +function MatrixAlgebraKit.left_polar!(t::AbstractTensorMap, WP, alg::BlockAlgorithm) + MatrixAlgebraKit.check_input(left_polar!, t, WP) + + foreachblock(t, WP...; alg.scheduler) do _, (b, w, p) + w′, p′ = left_polar!(b, (w, p), alg.alg) + # deal with the case where the output is not the same as the input + w === w′ || copyto!(w, w′) + p === p′ || copyto!(p, p′) + return nothing + end + + return WP +end + +function MatrixAlgebraKit.default_polar_algorithm(t::AbstractTensorMap{<:BlasFloat}; + scheduler=default_blockscheduler(t), + kwargs...) + return BlockAlgorithm(PolarViaSVD(LAPACK_DivideAndConquer(; kwargs...)), + scheduler) +end + +# Orthogonalization +# ----------------- +function MatrixAlgebraKit.check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)) + # scalartype checks + @check_eltype V t + isnothing(C) || @check_eltype C t + + # space checks + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + space(V) == (codomain(t) ← V_C) || + throw(SpaceMismatch("`left_orth!(t, (V, C))` requires `space(V) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t))))`")) + isnothing(C) || space(C) == (V_C ← domain(t)) || + throw(SpaceMismatch("`left_orth!(t, (V, C))` requires `space(C) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`")) + + return nothing +end + +function MatrixAlgebraKit.check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)) + # scalartype checks + isnothing(C) || @check_eltype C t + @check_eltype Vᴴ t + + # space checks + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + isnothing(C) || space(C) == (codomain(t) ← V_C) || + throw(SpaceMismatch("`right_orth!(t, (C, Vᴴ))` requires `space(C) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`")) + space(Vᴴ) == (V_dom ← domain(t)) || + throw(SpaceMismatch("`right_orth!(t, (C, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`")) + + return nothing +end + +function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), t::AbstractTensorMap) + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + V = similar(t, codomain(t) ← V_C) + C = similar(t, V_C ← domain(t)) + return V, C +end + +function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTensorMap) + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + C = similar(t, codomain(t) ← V_C) + Vᴴ = similar(t, V_C ← domain(t)) + return C, Vᴴ +end + +function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC; kwargs...) + MatrixAlgebraKit.check_input(left_orth!, t, VC) + atol = get(kwargs, :atol, 0) + rtol = get(kwargs, :rtol, 0) + kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd) + + if !(iszero(atol) && iszero(rtol)) && kind != :svd + throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + end + + if kind == :qr + alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(qr_compact!, t)) + return qr_compact!(t, VC, alg) + elseif kind == :qrpos + alg = get(kwargs, :alg, + MatrixAlgebraKit.select_algorithm(qr_compact!, t; positive=true)) + return qr_compact!(t, VC, alg) + elseif kind == :polar + alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(left_polar!, t)) + return left_polar!(t, VC, alg) + elseif kind == :svd && iszero(atol) && iszero(rtol) + alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(svd_compact!, t)) + V, C = VC + S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C)) + U, S, Vᴴ = svd_compact!(t, (V, S, C), alg) + return U, lmul!(S, Vᴴ) + elseif kind == :svd + alg_svd = MatrixAlgebraKit.select_algorithm(svd_compact!, t) + trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol) + alg = get(kwargs, :alg, MatrixAlgebraKit.TruncatedAlgorithm(alg_svd, trunc)) + V, C = VC + S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C)) + U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg) + return U, lmul!(S, Vᴴ) + else + throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) + end +end + +# Truncation +# ---------- +# TODO: technically we could do this truncation in-place, but this might not be worth it +function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), + trunc::MatrixAlgebraKit.TruncationKeepAbove) + atol = max(trunc.atol, norm(S) * trunc.rtol) + V_truncated = spacetype(S)(c => findlast(>=(atol), b.diag) for (c, b) in blocks(S)) + + Ũ = similar(U, codomain(U) ← V_truncated) + for (c, b) in blocks(Ũ) + copy!(b, @view(block(U, c)[:, 1:size(b, 2)])) + end + + S̃ = DiagonalTensorMap{scalartype(S)}(undef, V_truncated) + for (c, b) in blocks(S̃) + copy!(b.diag, @view(block(S, c).diag[1:size(b, 1)])) + end + + Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) + for (c, b) in blocks(Ṽᴴ) + copy!(b, @view(block(Vᴴ, c)[1:size(b, 1), :])) + end + + return Ũ, S̃, Ṽᴴ +end diff --git a/test/factorizations.jl b/test/factorizations.jl new file mode 100644 index 000000000..a30f005e3 --- /dev/null +++ b/test/factorizations.jl @@ -0,0 +1,218 @@ +for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) + V1, V2, V3, V4, V5 = V + @assert V3 * V4 * V2 ≿ V1' * V5' # necessary for leftorth tests + @assert V3 * V4 ≾ V1' * V2' * V5' # necessary for rightorth tests +end + +spacelist = try + if ENV["CI"] == "true" + println("Detected running on CI") + if Sys.iswindows() + (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂) + elseif Sys.isapple() + (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VfU₁, VfSU₂)#, VSU₃) + else + (Vtr, Vℤ₂, Vfℤ₂, VU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) + end + else + (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) + end +catch + (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) +end + +@timedtestset "Factorizatios with symmetry: $(sectortype(first(V)))" for V in spacelist + V1, V2, V3, V4, V5 = V + W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 + for T in (Float32, ComplexF64), adj in (false, true) + t = adj ? rand(T, W)' : rand(T, W) + @testset "leftorth with $alg" for alg in + (TensorKit.QR(), TensorKit.QRpos(), + TensorKit.QL(), TensorKit.QLpos(), + TensorKit.Polar(), TensorKit.SVD(), + TensorKit.SDD()) + Q, R = @constinferred leftorth(t, ((3, 4, 2), (1, 5)); alg=alg) + QdQ = Q' * Q + @test QdQ ≈ one(QdQ) + @test Q * R ≈ permute(t, ((3, 4, 2), (1, 5))) + if alg isa Polar + @test isposdef(R) + @test domain(R) == codomain(R) == space(t, 1)' ⊗ space(t, 5)' + end + end + @testset "leftnull with $alg" for alg in + (TensorKit.QR(), TensorKit.SVD(), + TensorKit.SDD()) + N = @constinferred leftnull(t, ((3, 4, 2), (1, 5)); alg=alg) + NdN = N' * N + @test NdN ≈ one(NdN) + @test norm(N' * permute(t, ((3, 4, 2), (1, 5)))) < + 100 * eps(norm(t)) + end + @testset "rightorth with $alg" for alg in + (TensorKit.RQ(), TensorKit.RQpos(), + TensorKit.LQ(), TensorKit.LQpos(), + TensorKit.Polar(), TensorKit.SVD(), + TensorKit.SDD()) + L, Q = @constinferred rightorth(t, ((3, 4), (2, 1, 5)); alg=alg) + QQd = Q * Q' + @test QQd ≈ one(QQd) + @test L * Q ≈ permute(t, ((3, 4), (2, 1, 5))) + if alg isa Polar + @test isposdef(L) + @test domain(L) == codomain(L) == space(t, 3) ⊗ space(t, 4) + end + end + @testset "rightnull with $alg" for alg in + (TensorKit.LQ(), TensorKit.SVD(), + TensorKit.SDD()) + M = @constinferred rightnull(t, ((3, 4), (2, 1, 5)); alg=alg) + MMd = M * M' + @test MMd ≈ one(MMd) + @test norm(permute(t, ((3, 4), (2, 1, 5))) * M') < + 100 * eps(norm(t)) + end + @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) + U, S, V = @constinferred tsvd(t, ((3, 4, 2), (1, 5)); alg=alg) + UdU = U' * U + @test UdU ≈ one(UdU) + VVd = V * V' + @test VVd ≈ one(VVd) + t2 = permute(t, ((3, 4, 2), (1, 5))) + @test U * S * V ≈ t2 + + s = LinearAlgebra.svdvals(t2) + s′ = LinearAlgebra.diag(S) + for (c, b) in s + @test b ≈ s′[c] + end + end + @testset "cond and rank" begin + t2 = permute(t, ((3, 4, 2), (1, 5))) + d1 = dim(codomain(t2)) + d2 = dim(domain(t2)) + @test rank(t2) == min(d1, d2) + M = leftnull(t2) + @test rank(M) == max(d1, d2) - min(d1, d2) + t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) + @test cond(t3) ≈ one(real(T)) + @test rank(t3) == dim(V1 ⊗ V2) + t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) + t4 = (t4 + t4') / 2 + vals = LinearAlgebra.eigvals(t4) + λmax = maximum(s -> maximum(abs, s), values(vals)) + λmin = minimum(s -> minimum(abs, s), values(vals)) + @test cond(t4) ≈ λmax / λmin + end + end + @testset "empty tensor" begin + for T in (Float32, ComplexF64) + t = randn(T, V1 ⊗ V2, zero(V1)) + @testset "leftorth with $alg" for alg in + (TensorKit.QR(), TensorKit.QRpos(), + TensorKit.QL(), TensorKit.QLpos(), + TensorKit.Polar(), TensorKit.SVD(), + TensorKit.SDD()) + Q, R = @constinferred leftorth(t; alg=alg) + @test Q == t + @test dim(Q) == dim(R) == 0 + end + @testset "leftnull with $alg" for alg in + (TensorKit.QR(), TensorKit.SVD(), + TensorKit.SDD()) + N = @constinferred leftnull(t; alg=alg) + @test N' * N ≈ id(domain(N)) + @test N * N' ≈ id(codomain(N)) + end + @testset "rightorth with $alg" for alg in + (TensorKit.RQ(), TensorKit.RQpos(), + TensorKit.LQ(), TensorKit.LQpos(), + TensorKit.Polar(), TensorKit.SVD(), + TensorKit.SDD()) + L, Q = @constinferred rightorth(copy(t'); alg=alg) + @test Q == t' + @test dim(Q) == dim(L) == 0 + end + @testset "rightnull with $alg" for alg in + (TensorKit.LQ(), TensorKit.SVD(), + TensorKit.SDD()) + M = @constinferred rightnull(copy(t'); alg=alg) + @test M * M' ≈ id(codomain(M)) + @test M' * M ≈ id(domain(M)) + end + @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) + U, S, V = @constinferred tsvd(t; alg=alg) + @test U == t + @test dim(U) == dim(S) == dim(V) + end + @testset "cond and rank" begin + @test rank(t) == 0 + W2 = zero(V1) * zero(V2) + t2 = rand(W2, W2) + @test rank(t2) == 0 + @test cond(t2) == 0.0 + end + end + end + @testset "eig and isposdef" begin + for T in (Float32, ComplexF64) + t = rand(T, V1 ⊗ V1' ⊗ V2 ⊗ V2') + D, V = eigen(t, ((1, 3), (2, 4))) + t2 = permute(t, ((1, 3), (2, 4))) + @test t2 * V ≈ V * D + + d = LinearAlgebra.eigvals(t2; sortby=nothing) + d′ = LinearAlgebra.diag(D) + for (c, b) in d + @test b ≈ d′[c] + end + + # Somehow moving these test before the previous one gives rise to errors + # with T=Float32 on x86 platforms. Is this an OpenBLAS issue? + VdV = V' * V + VdV = (VdV + VdV') / 2 + @test isposdef(VdV) + + @test !isposdef(t2) # unlikely for non-hermitian map + t2 = (t2 + t2') + D, V = eigen(t2) + VdV = V' * V + @test VdV ≈ one(VdV) + D̃, Ṽ = @constinferred eigh(t2) + @test D ≈ D̃ + @test V ≈ Ṽ + λ = minimum(minimum(real(LinearAlgebra.diag(b))) + for (c, b) in blocks(D)) + @test cond(Ṽ) ≈ one(real(T)) + @test isposdef(t2) == isposdef(λ) + @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) + @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) + end + end + @testset "Tensor truncation" begin + for T in (Float32, ComplexF64), p in (1, 2, 3, Inf), adj in (false, true) + t = adj ? rand(T, V1 ⊗ V2 ⊗ V3, V4 ⊗ V5) : rand(T, V4 ⊗ V5, V1 ⊗ V2 ⊗ V3)' + + U₀, S₀, V₀, = tsvd(t) + t = rmul!(t, 1 / norm(S₀, p)) + U, S, V, ϵ = @constinferred tsvd(t; trunc=truncerr(5e-1), p=p) + # @show p, ϵ + # @show domain(S) + # @test min(space(S,1), space(S₀,1)) != space(S₀,1) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncerr(nextfloat(ϵ)), p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S)))), + p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) + # results with truncationcutoff cannot be compared because they don't take degeneracy into account, and thus truncate differently + U, S, V, ϵ = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀))), p=p) + # @show p, ϵ + # @show domain(S) + # @test min(space(S,1), space(S₀,1)) != space(S₀,1) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) + end + end +end diff --git a/test/paul.jl b/test/paul.jl new file mode 100644 index 000000000..249ed1bae --- /dev/null +++ b/test/paul.jl @@ -0,0 +1,65 @@ +using Zygote, TensorKit + +_safe_pow(a::Real, pow::Real, tol::Real) = (pow < 0 && abs(a) < tol) ? zero(a) : a^pow + +# Element-wise multiplication of TensorMaps respecting block structure +function _elementwise_mult(a₁::AbstractTensorMap, a₂::AbstractTensorMap) + dst = similar(a₁) + for (k, b) in blocks(dst) + copyto!(b, block(a₁, k) .* block(a₂, k)) + end + return dst +end +""" + sdiag_pow(s, pow::Real; tol::Real=eps(scalartype(s))^(3 / 4)) + +Compute `s^pow` for a diagonal matrix `s`. +""" +function sdiag_pow(s::DiagonalTensorMap, pow::Real; tol::Real=eps(scalartype(s))^(3 / 4)) + # Relative tol w.r.t. largest singular value (use norm(∘, Inf) to make differentiable) + tol *= norm(s, Inf) + spow = DiagonalTensorMap(_safe_pow.(s.data, pow, tol), space(s, 1)) + return spow +end +function sdiag_pow(s::AbstractTensorMap{T,S,1,1}, pow::Real; + tol::Real=eps(scalartype(s))^(3 / 4)) where {T,S} + # Relative tol w.r.t. largest singular value (use norm(∘, Inf) to make differentiable) + tol *= norm(s, Inf) + spow = similar(s) + for (k, b) in blocks(s) + copyto!(block(spow, k), + LinearAlgebra.diagm(_safe_pow.(LinearAlgebra.diag(b), pow, tol))) + end + return spow +end + +function ChainRulesCore.rrule(::typeof(sdiag_pow), + s::AbstractTensorMap, + pow::Real; + tol::Real=eps(scalartype(s))^(3 / 4),) + tol *= norm(s, Inf) + spow = sdiag_pow(s, pow; tol) + spow_minus1_conj = scale!(sdiag_pow(s', pow - 1; tol), pow) + function sdiag_pow_pullback(c̄_) + c̄ = unthunk(c̄_) + return (ChainRulesCore.NoTangent(), _elementwise_mult(c̄, spow_minus1_conj)) + end + return spow, sdiag_pow_pullback +end + +function svd_fixed_point(A, U, S, V) + S⁻¹ = sdiag_pow(S, -1) + return (A * V' * S⁻¹ - U, DiagonalTensorMap(U' * A * V' * S⁻¹) - one(S), + S⁻¹ * U' * A - V) +end + +using Zygote + +V = ComplexSpace(3)^2 +A = randn(ComplexF64, V, V) +U, S, V = tsvd(A) + +Zygote.gradient(A, U, S, V) do A, U, S, V + du, ds, dv = svd_fixed_point(A, U, S, V) + return norm(du) + norm(ds) + norm(dv) +end From a9605b65f985ab2a8cc060bc2783426f33b7dd9a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 2 Apr 2025 18:22:43 -0400 Subject: [PATCH 016/126] start adding truncated svd --- src/tensors/matrixalgebrakit.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index dd45d6203..a1f8682c3 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -130,6 +130,12 @@ function MatrixAlgebraKit.svd_compact!(t::AbstractTensorMap, (U, S, Vᴴ), return U, S, Vᴴ end +function MatrixAlgebraKit.svd_trunc!(t::AbstractTensorMap, USVᴴ, + alg::MatrixAlgebraKit.TruncatedAlgorithm) + USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) + return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ′, alg.trunc) +end + function MatrixAlgebraKit.default_svd_algorithm(t::AbstractTensorMap{<:BlasFloat}; scheduler=default_blockscheduler(t), kwargs...) From 77b6fbb85a7b5ebc14111737e96642ad8dbf57dd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 2 Apr 2025 18:25:33 -0400 Subject: [PATCH 017/126] patch through leftorth --- src/tensors/factorizations.jl | 82 ++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 39 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 10adbd623..65d6fa83c 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -311,6 +311,13 @@ end #------------------------------------------------------------------------------------------ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} +function _reverse!(t::AbstractTensorMap; dims=:) + for (c, b) in blocks(t) + reverse!(b; dims) + end + return t +end + function leftorth!(t::TensorMap{<:RealOrComplexFloat}; alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(), atol::Real=zero(float(real(scalartype(t)))), @@ -318,47 +325,44 @@ function leftorth!(t::TensorMap{<:RealOrComplexFloat}; eps(real(float(one(scalartype(t))))) * iszero(atol)) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:leftorth!) - if !iszero(rtol) - atol = max(atol, rtol * norm(t)) - end - I = sectortype(t) - dims = SectorDict{I,Int}() - # compute QR factorization for each block - if !isempty(blocks(t)) - generator = Base.Iterators.map(blocks(t)) do (c, b) - Qc, Rc = MatrixAlgebra.leftorth!(b, alg, atol) - dims[c] = size(Qc, 2) - return c => (Qc, Rc) + if alg isa QR + return left_orth!(t; kind=:qr, atol, rtol) + elseif alg isa QRpos + return left_orth!(t; kind=:qrpos, atol, rtol) + elseif alg isa SDD + return left_orth!(t; kind=:svd, atol, rtol) + elseif alg isa Polar + return left_orth!(t; kind=:polar, atol, rtol) + elseif alg isa SVD + kind = :svd + if iszero(atol) && iszero(rtol) + alg′ = LAPACK_QRIteration() + return left_orth!(t; kind, alg=BlockAlgorithm(alg′, default_blockscheduler(t)), + atol, rtol) + else + trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol) + svd_alg = LAPACK_QRIteration() + scheduler = default_blockscheduler(t) + alg′ = MatrixAlgebraKit.TruncatedAlgorithm(BlockAlgorithm(svd_alg, scheduler), + trunc) + return left_orth!(t; kind, alg=alg′, atol, rtol) end - QRdata = SectorDict(generator) + elseif alg isa QL + _reverse!(t; dims=2) + Q, R = left_orth!(t; kind=:qr, atol, rtol) + _reverse!(Q; dims=2) + _reverse!(R) + return Q, R + elseif alg isa QLpos + _reverse!(t; dims=2) + Q, R = left_orth!(t; kind=:qrpos, atol, rtol) + _reverse!(Q; dims=2) + _reverse!(R) + return Q, R end - # construct new space - S = spacetype(t) - V = S(dims) - if alg isa Polar - @assert V ≅ domain(t) - W = domain(t) - elseif length(domain(t)) == 1 && domain(t) ≅ V - W = domain(t) - elseif length(codomain(t)) == 1 && codomain(t) ≅ V - W = codomain(t) - else - W = ProductSpace(V) - end - - # construct output tensors - T = float(scalartype(t)) - Q = similar(t, T, codomain(t) ← W) - R = similar(t, T, W ← domain(t)) - if !isempty(blocks(t)) - for (c, (Qc, Rc)) in QRdata - copy!(block(Q, c), Qc) - copy!(block(R, c), Rc) - end - end - return Q, R + throw(ArgumentError("Algorithm $alg not implemented for leftorth!")) end function leftnull!(t::TensorMap{<:RealOrComplexFloat}; @@ -685,8 +689,8 @@ function LinearAlgebra.ishermitian(t::TensorMap) end function LinearAlgebra.isposdef!(t::TensorMap) - domain(t) == codomain(t) || - throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) + domain(t) ≅ codomain(t) || + throw(SpaceMismatch("`isposdef` requires domain and codomain to be isomorphic")) InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false for (c, b) in blocks(t) isposdef!(b) || return false From 80b4f0e7c46fc6c00b9ea8cdcd2b11c18f68714f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 8 May 2025 16:31:02 -0400 Subject: [PATCH 018/126] Add `isisometry` --- src/TensorKit.jl | 2 +- src/auxiliary/linalg.jl | 2 ++ src/tensors/factorizations.jl | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 48b5938fd..0e7a1cede 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -73,7 +73,7 @@ export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby! export leftorth, rightorth, leftnull, rightnull, leftorth!, rightorth!, leftnull!, rightnull!, tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, - isposdef, isposdef!, ishermitian, sylvester, rank, cond + isposdef, isposdef!, ishermitian, isisometry, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! export catdomain, catcodomain, absorb, absorb! diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index 82e8600f0..fa6e5e248 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -84,6 +84,8 @@ end safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s)) safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) +isisometry(A::StridedMatrix; kwargs...) = isapprox(A' * A, LinearAlgebra.I, kwargs...) + function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real) iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) m, n = size(A) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 65d6fa83c..cb6c55110 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -268,6 +268,11 @@ function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) return isposdef!(tcopy) end +function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) + t = permute(t, (p₁, p₂); copy=false) + return isisometry(t) +end + function tsvd(t::AbstractTensorMap; kwargs...) tcopy = copy_oftype(t, float(scalartype(t))) return tsvd!(tcopy; kwargs...) @@ -697,3 +702,12 @@ function LinearAlgebra.isposdef!(t::TensorMap) end return true end + +# TODO: tolerances are per-block, not global or weighted - does that matter? +function isisometry(t::AbstractTensorMap; kwargs...) + domain(t) ≾ codomain(t) || return false + for (_, b) in blocks(t) + MatrixAlgebra.isisometry(b; kwargs...) || return false + end + return true +end From 4a3af28a9bfd931df6bcefdad9f6452d49c20e12 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 8 May 2025 16:31:10 -0400 Subject: [PATCH 019/126] Revert isposdef changes --- src/tensors/factorizations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index cb6c55110..5686bb55c 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -694,8 +694,8 @@ function LinearAlgebra.ishermitian(t::TensorMap) end function LinearAlgebra.isposdef!(t::TensorMap) - domain(t) ≅ codomain(t) || - throw(SpaceMismatch("`isposdef` requires domain and codomain to be isomorphic")) + domain(t) == codomain(t) || + throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false for (c, b) in blocks(t) isposdef!(b) || return false From 6d68f3fa1195ebce278bdb688c4fe57d93700f68 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 8 May 2025 16:31:23 -0400 Subject: [PATCH 020/126] preinitialize polar output --- src/tensors/factorizations.jl | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 5686bb55c..3436fe360 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -331,19 +331,22 @@ function leftorth!(t::TensorMap{<:RealOrComplexFloat}; InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:leftorth!) + VC = MatrixAlgebraKit.initialize_output(left_orth!, t) + if alg isa QR - return left_orth!(t; kind=:qr, atol, rtol) + return left_orth!(t, VC; kind=:qr, atol, rtol) elseif alg isa QRpos - return left_orth!(t; kind=:qrpos, atol, rtol) + return left_orth!(t, VC; kind=:qrpos, atol, rtol) elseif alg isa SDD - return left_orth!(t; kind=:svd, atol, rtol) + return left_orth!(t, VC; kind=:svd, atol, rtol) elseif alg isa Polar - return left_orth!(t; kind=:polar, atol, rtol) + return left_orth!(t, VC; kind=:polar, atol, rtol) elseif alg isa SVD kind = :svd if iszero(atol) && iszero(rtol) alg′ = LAPACK_QRIteration() - return left_orth!(t; kind, alg=BlockAlgorithm(alg′, default_blockscheduler(t)), + return left_orth!(t, VC; kind, + alg=BlockAlgorithm(alg′, default_blockscheduler(t)), atol, rtol) else trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol) @@ -351,17 +354,17 @@ function leftorth!(t::TensorMap{<:RealOrComplexFloat}; scheduler = default_blockscheduler(t) alg′ = MatrixAlgebraKit.TruncatedAlgorithm(BlockAlgorithm(svd_alg, scheduler), trunc) - return left_orth!(t; kind, alg=alg′, atol, rtol) + return left_orth!(t, VC; kind, alg=alg′, atol, rtol) end elseif alg isa QL _reverse!(t; dims=2) - Q, R = left_orth!(t; kind=:qr, atol, rtol) + Q, R = left_orth!(t, VC; kind=:qr, atol, rtol) _reverse!(Q; dims=2) _reverse!(R) return Q, R elseif alg isa QLpos _reverse!(t; dims=2) - Q, R = left_orth!(t; kind=:qrpos, atol, rtol) + Q, R = left_orth!(t, VC; kind=:qrpos, atol, rtol) _reverse!(Q; dims=2) _reverse!(R) return Q, R From 26ca9607967bbeb3e45c6a93412e1e1fcc8f2510 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 9 May 2025 12:01:14 -0400 Subject: [PATCH 021/126] Bump to MatrixAlgebraKit v0.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4861dea5c..474c57a10 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.1.1" +MatrixAlgebraKit = "0.2" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" From ed3978531709bbe07bc21d2c1b73a1f1b55bed0d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 9 May 2025 12:03:41 -0400 Subject: [PATCH 022/126] Rework left_orth --- src/tensors/factorizations.jl | 85 +++++++++++++++++---------------- src/tensors/matrixalgebrakit.jl | 83 ++++++++++++++++++++------------ 2 files changed, 96 insertions(+), 72 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 3436fe360..57614e6ef 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -326,51 +326,54 @@ end function leftorth!(t::TensorMap{<:RealOrComplexFloat}; alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(), atol::Real=zero(float(real(scalartype(t)))), - rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) : - eps(real(float(one(scalartype(t))))) * iszero(atol)) + rtol::Real=(alg ∉ (SVD(), SDD())) ? + zero(float(real(scalartype(t)))) : + eps(real(float(one(scalartype(t))))) * + iszero(atol)) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:leftorth!) - - VC = MatrixAlgebraKit.initialize_output(left_orth!, t) - - if alg isa QR - return left_orth!(t, VC; kind=:qr, atol, rtol) - elseif alg isa QRpos - return left_orth!(t, VC; kind=:qrpos, atol, rtol) - elseif alg isa SDD - return left_orth!(t, VC; kind=:svd, atol, rtol) - elseif alg isa Polar - return left_orth!(t, VC; kind=:polar, atol, rtol) - elseif alg isa SVD - kind = :svd - if iszero(atol) && iszero(rtol) - alg′ = LAPACK_QRIteration() - return left_orth!(t, VC; kind, - alg=BlockAlgorithm(alg′, default_blockscheduler(t)), - atol, rtol) - else - trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol) - svd_alg = LAPACK_QRIteration() - scheduler = default_blockscheduler(t) - alg′ = MatrixAlgebraKit.TruncatedAlgorithm(BlockAlgorithm(svd_alg, scheduler), - trunc) - return left_orth!(t, VC; kind, alg=alg′, atol, rtol) - end - elseif alg isa QL - _reverse!(t; dims=2) - Q, R = left_orth!(t, VC; kind=:qr, atol, rtol) - _reverse!(Q; dims=2) - _reverse!(R) - return Q, R - elseif alg isa QLpos - _reverse!(t; dims=2) - Q, R = left_orth!(t, VC; kind=:qrpos, atol, rtol) - _reverse!(Q; dims=2) - _reverse!(R) - return Q, R + if alg == SVD() || alg == SDD() + return _leftorth!(t, alg; atol, rtol) + else + (iszero(atol) && iszero(rtol)) || + throw(ArgumentError("`leftorth!` with nonzero atol or rtol requires SVD or SDD algorithm")) + return _leftorth!(t, alg) end +end - throw(ArgumentError("Algorithm $alg not implemented for leftorth!")) +# this promotes the algorithm to a positional argument for type stability reasons +# since polar has different number of output legs +# TODO: this seems like duplication from MatrixAlgebraKit.left_orth!, but that function +# only has its logic with the output already specified, which breaks for polar +function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD}; atol::Real, + rtol::Real) + alg′ = alg == SVD() ? MatrixAlgebraKit.LAPACK_QRIteration() : + MatrixAlgebraKit.LAPACK_DivideAndConquer() + alg_svd = BlockAlgorithm(alg′, default_blockscheduler(t)) + if iszero(atol) && iszero(rtol) + U, S, Vᴴ = svd_compact!(t, alg_svd) + return U, lmul!(S, Vᴴ) + else + trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol) + alg_svd = MatrixAlgebraKit.select_algorithm(svd_trunc!, t; trunc, + alg=alg_svd) + + U, S, Vᴴ = svd_trunc!(t, alg_svd) + return U, lmul!(S, Vᴴ) + end +end +function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{QR,QRpos}) + return qr_compact!(t; positive=alg == QRpos()) +end +function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{QL,QLpos}) + _reverse!(t; dims=2) + Q, R = qr_compact!(t; positive=alg == QLpos()) + _reverse!(Q; dims=2) + _reverse!(R) + return Q, R +end +function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, ::Polar) + return MatrixAlgebraKit.left_polar!(t) end function leftnull!(t::TensorMap{<:RealOrComplexFloat}; diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index a1f8682c3..a6952a055 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -21,6 +21,18 @@ macro check_eltype(x, y, f=:identity, g=:eltype) return esc(:($g($x) == $f($g($y)) || throw(ArgumentError($msg)))) end +function MatrixAlgebraKit._select_algorithm(_, ::AbstractTensorMap, + alg::MatrixAlgebraKit.AbstractAlgorithm) + return alg +end +function MatrixAlgebraKit._select_algorithm(f, t::AbstractTensorMap, alg::NamedTuple) + return MatrixAlgebraKit.select_algorithm(f, t; alg...) +end + +function _select_truncation(f, ::AbstractTensorMap, + trunc::MatrixAlgebraKit.TruncationStrategy) + return trunc +end # function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) # T = scalartype(t) # return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) @@ -76,7 +88,7 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTens ::MatrixAlgebraKit.AbstractAlgorithm) V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) - U = similar(t, domain(t) ← V_cod) + U = similar(t, codomain(t) ← V_cod) S = similar(t, real(scalartype(t)), V_cod ← V_dom) Vᴴ = similar(t, V_dom ← domain(t)) return U, S, Vᴴ @@ -476,18 +488,19 @@ function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P)) @check_eltype P t # space checks - space(W) == (codomain(t) ← fuse(domain(t))) || + space(W) == space(t) || throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`")) - space(P) == (fuse(domain(t)) ← domain(t)) || + space(P) == (domain(t) ← domain(t)) || throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`")) return nothing end # TODO: do we really not want to fuse the spaces? -function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap) - W = similar(t, codomain(t) ← fuse(domain(t))) - P = similar(t, fuse(domain(t)) ← domain(t)) +function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap, + ::MatrixAlgebraKit.AbstractAlgorithm) + W = similar(t, space(t)) + P = similar(t, domain(t) ← domain(t)) return W, P end @@ -558,40 +571,48 @@ function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTe return C, Vᴴ end -function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC; kwargs...) - MatrixAlgebraKit.check_input(left_orth!, t, VC) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd) - - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) +function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC; + trunc=nothing, + kind=isnothing(trunc) ? + :qr : :svd, + alg_qr=(; positive=true), + alg_polar=(;), + alg_svd=(;)) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(qr_compact!, t)) - return qr_compact!(t, VC, alg) - elseif kind == :qrpos - alg = get(kwargs, :alg, - MatrixAlgebraKit.select_algorithm(qr_compact!, t; positive=true)) - return qr_compact!(t, VC, alg) - elseif kind == :polar - alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(left_polar!, t)) - return left_polar!(t, VC, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, MatrixAlgebraKit.select_algorithm(svd_compact!, t)) + alg_qr′ = MatrixAlgebraKit._select_algorithm(qr_compact!, t, alg_qr) + return qr_compact!(t, VC, alg_qr′) + end + + if kind == :polar + alg_polar′ = MatrixAlgebraKit._select_algorithm(left_polar!, t, alg_polar) + return left_polar!(t, VC, alg_polar′) + end + + if kind == :svd && isnothing(trunc) + alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_compact!, t, alg_svd) V, C = VC S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C)) - U, S, Vᴴ = svd_compact!(t, (V, S, C), alg) + U, S, Vᴴ = svd_compact!(t, (V, S, C), alg_svd′) return U, lmul!(S, Vᴴ) - elseif kind == :svd - alg_svd = MatrixAlgebraKit.select_algorithm(svd_compact!, t) - trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol) - alg = get(kwargs, :alg, MatrixAlgebraKit.TruncatedAlgorithm(alg_svd, trunc)) + end + + if kind == :svd + alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_compact!, t, alg_svd) + alg_svd_trunc = MatrixAlgebraKit.select_algorithm(svd_trunc!, t; trunc, + alg=alg_svd′) V, C = VC S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C)) - U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg) + U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg_svd_trunc) return U, lmul!(S, Vᴴ) + end + + throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) +end + else throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) end From 09e64fa9cdd731332b1e4722e45686a7b822c279 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 9 May 2025 12:04:17 -0400 Subject: [PATCH 023/126] Rework left_null --- src/tensors/factorizations.jl | 40 ++++------ src/tensors/matrixalgebrakit.jl | 130 +++++++++++++++++++++++++++++++- 2 files changed, 144 insertions(+), 26 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 57614e6ef..97c853632 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -383,36 +383,26 @@ function leftnull!(t::TensorMap{<:RealOrComplexFloat}; eps(real(float(one(scalartype(t))))) * iszero(atol)) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:leftnull!) - if !iszero(rtol) - atol = max(atol, rtol * norm(t)) - end - I = sectortype(t) - dims = SectorDict{I,Int}() - # compute QR factorization for each block - V = codomain(t) - if !isempty(blocksectors(V)) - generator = Base.Iterators.map(blocksectors(V)) do c - Nc = MatrixAlgebra.leftnull!(block(t, c), alg, atol) - dims[c] = size(Nc, 2) - return c => Nc + if alg == SVD() || alg == SDD() + kind = :svd + alg_svd = BlockAlgorithm(alg == SVD() ? MatrixAlgebraKit.LAPACK_QRIteration() : + MatrixAlgebraKit.LAPACK_DivideAndConquer(), + default_blockscheduler(t)) + trunc = if iszero(atol) && iszero(rtol) + nothing + else + (; atol, rtol) end - Ndata = SectorDict(generator) + return left_null!(t; kind, alg_svd, trunc) end - # construct new space - S = spacetype(t) - W = S(dims) + (iszero(atol) && iszero(rtol)) || + throw(ArgumentError("`leftnull!` with nonzero atol or rtol requires SVD or SDD algorithm")) - # construct output tensor - T = float(scalartype(t)) - N = similar(t, T, V ← W) - if !isempty(blocksectors(V)) - for (c, Nc) in Ndata - copy!(block(N, c), Nc) - end - end - return N + kind = :qr + alg_qr = (; positive=alg == QRpos()) + return left_null!(t; kind, alg_qr) end function rightorth!(t::TensorMap{<:RealOrComplexFloat}; diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index a6952a055..f836e51be 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -33,6 +33,14 @@ function _select_truncation(f, ::AbstractTensorMap, trunc::MatrixAlgebraKit.TruncationStrategy) return trunc end +function _select_truncation(::typeof(left_null!), ::AbstractTensorMap, trunc::NamedTuple) + return MatrixAlgebraKit.null_truncation_strategy(; trunc...) +end + +function MatrixAlgebraKit.diagview(t::AbstractTensorMap) + return SectorDict(c => MatrixAlgebraKit.diagview(b) for (c, b) in blocks(t)) +end + # function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) # T = scalartype(t) # return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) @@ -103,6 +111,11 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractT return U, S, Vᴴ end +function MatrixAlgebraKit.initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, + alg::MatrixAlgebraKit.AbstractAlgorithm) + return MatrixAlgebraKit.initialize_output(svd_compact!, t, alg) +end + # TODO: svd_vals function MatrixAlgebraKit.svd_full!(t::AbstractTensorMap, (U, S, Vᴴ), @@ -613,8 +626,69 @@ function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC; throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) end +# Nullspace +# --------- +function MatrixAlgebraKit.check_input(::typeof(left_null!), t::AbstractTensorMap, N) + # scalartype checks + @check_eltype N t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = setdiff(fuse(codomain(t)), V_Q) + space(N) == (codomain(t) ← V_N) || + throw(SpaceMismatch("`left_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) + + return nothing +end + +function MatrixAlgebraKit.initialize_output(::typeof(left_null!), t::AbstractTensorMap) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = setdiff(fuse(codomain(t)), V_Q) + N = similar(t, codomain(t) ← V_N) + return N +end + +# TODO: the following functions shouldn't be necessary if the AbstractArray restrictions are +# removed +function MatrixAlgebraKit.left_null(t::AbstractTensorMap; kwargs...) + return left_null!(MatrixAlgebraKit.copy_input(left_null, t); kwargs...) +end +function MatrixAlgebraKit.left_null!(t::AbstractTensorMap; kwargs...) + N = MatrixAlgebraKit.initialize_output(left_null!, t) + return left_null!(t, N; kwargs...) +end + +function MatrixAlgebraKit.left_null!(t::AbstractTensorMap, N; + trunc=nothing, + kind=isnothing(trunc) ? :qr : :svd, + alg_qr=(; positive=true), + alg_svd=(;)) + MatrixAlgebraKit.check_input(left_null!, t, N) + + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for left_null with kind=$kind")) + end + + if kind == :qr + alg_qr′ = MatrixAlgebraKit._select_algorithm(qr_null!, t, alg_qr) + return qr_null!(t, N, alg_qr′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd) + # TODO: refactor into separate function + U, _, _ = svd_full!(t, alg_svd′) + for (c, b) in blocks(N) + bU = block(U, c) + m, n = size(bU) + copy!(b, @view(bU[1:m, (n + 1):m])) + end + return N + elseif kind == :svd + alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd) + U, S, _ = svd_full!(t, alg_svd′) + trunc′ = _select_truncation(left_null!, t, trunc) + return MatrixAlgebraKit.truncate!(left_null!, (U, S), trunc′) else - throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) + throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) end end @@ -643,3 +717,57 @@ function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), return Ũ, S̃, Ṽᴴ end + +function MatrixAlgebraKit.truncate!(::typeof(left_null!), + (U, S)::Tuple{<:AbstractTensorMap, + <:AbstractTensorMap}, + strategy::MatrixAlgebraKit.TruncationStrategy) + extended_S = SectorDict(c => vcat(MatrixAlgebraKit.diagview(b), + zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) + for (c, b) in blocks(S)) + ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) + V_truncated = spacetype(S)(c => length(axes(b, 1)[ind[c]]) for (c, b) in blocks(S)) + Ũ = similar(U, codomain(U) ← V_truncated) + for (c, b) in blocks(Ũ) + copy!(b, @view(block(U, c)[:, ind[c]])) + end + return Ũ +end + +const BlockWiseTruncations = Union{MatrixAlgebraKit.TruncationKeepAbove, + MatrixAlgebraKit.TruncationKeepBelow, + MatrixAlgebraKit.TruncationKeepFiltered} + +# TODO: relative tolerances should be global +function MatrixAlgebraKit.findtruncated(values::SectorDict, strategy::BlockWiseTruncations) + return SectorDict(c => MatrixAlgebraKit.findtruncated(v, strategy) for (c, v) in values) +end +function MatrixAlgebraKit.findtruncated(vals::SectorDict, + strategy::MatrixAlgebraKit.TruncationKeepSorted) + allpairs = mapreduce(vcat, vals) do (c, v) + return map(Base.Fix1(=>, c), axes(v, 1)) + end + by((c, i)) = strategy.sortby(vals[c][i]) + sort!(allpairs; by, strategy.rev) + + howmany = zero(Base.promote_op(dim, valtype(values))) + i = 1 + while i ≤ length(allpairs) + howmany += dim(first(allpairs[i])) + + howmany == strategy.howmany && break + + if howmany > strategy.howmany + i -= 1 + break + end + + i += 1 + end + + ind = SectorDict(c => allpairs[findall(==(c) ∘ first, view(allpairs, 1:i))] + for c in keys(vals)) + filter!(!isempty ∘ last, ind) # TODO: this might not be necessary + + return ind +end From f8447037251386c8f8a2e17bdcc8cdad759e18d4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 9 May 2025 12:04:27 -0400 Subject: [PATCH 024/126] include temporary tests --- test/factorizations.jl | 146 ++++++++++++++++++++++++++++++++++------- 1 file changed, 123 insertions(+), 23 deletions(-) diff --git a/test/factorizations.jl b/test/factorizations.jl index a30f005e3..ca9b510fb 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -1,5 +1,105 @@ +using TestEnv; +TestEnv.activate(); + +using Test +using TestExtras +using Random +using TensorKit +using Combinatorics +using TensorKit: ProductSector, fusiontensor, pentagon_equation, hexagon_equation +using TensorOperations +using Base.Iterators: take, product +# using SUNRepresentations: SUNIrrep +# const SU3Irrep = SUNIrrep{3} +using LinearAlgebra: LinearAlgebra +using Zygote: Zygote +using MatrixAlgebraKit + +const TK = TensorKit + +Random.seed!(1234) + +smallset(::Type{I}) where {I<:Sector} = take(values(I), 5) +function smallset(::Type{ProductSector{Tuple{I1,I2}}}) where {I1,I2} + iter = product(smallset(I1), smallset(I2)) + s = collect(i ⊠ j for (i, j) in iter if dim(i) * dim(j) <= 6) + return length(s) > 6 ? rand(s, 6) : s +end +function smallset(::Type{ProductSector{Tuple{I1,I2,I3}}}) where {I1,I2,I3} + iter = product(smallset(I1), smallset(I2), smallset(I3)) + s = collect(i ⊠ j ⊠ k for (i, j, k) in iter if dim(i) * dim(j) * dim(k) <= 6) + return length(s) > 6 ? rand(s, 6) : s +end +function randsector(::Type{I}) where {I<:Sector} + s = collect(smallset(I)) + a = rand(s) + while a == one(a) # don't use trivial label + a = rand(s) + end + return a +end +function hasfusiontensor(I::Type{<:Sector}) + try + fusiontensor(one(I), one(I), one(I)) + return true + catch e + if e isa MethodError + return false + else + rethrow(e) + end + end +end + +# spaces +Vtr = (ℂ^3, + (ℂ^4)', + ℂ^5, + ℂ^6, + (ℂ^7)') +Vℤ₂ = (ℂ[Z2Irrep](0 => 1, 1 => 1), + ℂ[Z2Irrep](0 => 1, 1 => 2)', + ℂ[Z2Irrep](0 => 3, 1 => 2)', + ℂ[Z2Irrep](0 => 2, 1 => 3), + ℂ[Z2Irrep](0 => 2, 1 => 5)) +Vfℤ₂ = (ℂ[FermionParity](0 => 1, 1 => 1), + ℂ[FermionParity](0 => 1, 1 => 2)', + ℂ[FermionParity](0 => 3, 1 => 2)', + ℂ[FermionParity](0 => 2, 1 => 3), + ℂ[FermionParity](0 => 2, 1 => 5)) +Vℤ₃ = (ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 2), + ℂ[Z3Irrep](0 => 3, 1 => 1, 2 => 1), + ℂ[Z3Irrep](0 => 2, 1 => 2, 2 => 1)', + ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 3), + ℂ[Z3Irrep](0 => 1, 1 => 3, 2 => 3)') +VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), + ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), + ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 3), + ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 3)') +VfU₁ = (ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 2), + ℂ[FermionNumber](0 => 3, 1 => 1, -1 => 1), + ℂ[FermionNumber](0 => 2, 1 => 2, -1 => 1)', + ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 3), + ℂ[FermionNumber](0 => 1, 1 => 3, -1 => 3)') +VCU₁ = (ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 2, 1 => 1), + ℂ[CU1Irrep]((0, 0) => 3, (0, 1) => 0, 1 => 1), + ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 0, 1 => 2)', + ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 2, 1 => 1), + ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 1, 1 => 2)') +VSU₂ = (ℂ[SU2Irrep](0 => 3, 1 // 2 => 1), + ℂ[SU2Irrep](0 => 2, 1 => 1), + ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)', + ℂ[SU2Irrep](0 => 2, 1 // 2 => 2), + ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') +VfSU₂ = (ℂ[FermionSpin](0 => 3, 1 // 2 => 1), + ℂ[FermionSpin](0 => 2, 1 => 1), + ℂ[FermionSpin](1 // 2 => 1, 1 => 1)', + ℂ[FermionSpin](0 => 2, 1 // 2 => 2), + ℂ[FermionSpin](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) V1, V2, V3, V4, V5 = V + @assert V3 * V4 * V2 ≿ V1' * V5' # necessary for leftorth tests @assert V3 * V4 ≾ V1' * V2' * V5' # necessary for rightorth tests end @@ -21,33 +121,33 @@ catch (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) end -@timedtestset "Factorizatios with symmetry: $(sectortype(first(V)))" for V in spacelist + +function test_leftorth(t, p, alg) + Q, R = @inferred leftorth(t, p; alg) + @test Q * R ≈ permute(t, p) + @test isisometry(Q) + if alg isa Polar + @test isposdef(R) + @test domain(R) == codomain(R) == domain(permute(space(t), p)) + end +end +function test_leftnull(t, p, alg) + N = @inferred leftnull(t, p; alg) + @test isisometry(N) + @test norm(N' * permute(t, p)) ≈ 0 atol= 100 * eps(norm(t)) +end + +# @timedtestset "Factorizations with symmetry: $(sectortype(first(V)))" for V in spacelist + V = collect(spacelist)[2] V1, V2, V3, V4, V5 = V W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 for T in (Float32, ComplexF64), adj in (false, true) - t = adj ? rand(T, W)' : rand(T, W) - @testset "leftorth with $alg" for alg in - (TensorKit.QR(), TensorKit.QRpos(), - TensorKit.QL(), TensorKit.QLpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - Q, R = @constinferred leftorth(t, ((3, 4, 2), (1, 5)); alg=alg) - QdQ = Q' * Q - @test QdQ ≈ one(QdQ) - @test Q * R ≈ permute(t, ((3, 4, 2), (1, 5))) - if alg isa Polar - @test isposdef(R) - @test domain(R) == codomain(R) == space(t, 1)' ⊗ space(t, 5)' - end + t = adj ? rand(T, W)' : rand(T, W); + @testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QRpos(), TensorKit.QL(), TensorKit.QLpos(), TensorKit.Polar(), TensorKit.SVD(), TensorKit.SDD()) + test_leftorth(t, ((3, 4, 2), (1, 5)), alg) end - @testset "leftnull with $alg" for alg in - (TensorKit.QR(), TensorKit.SVD(), - TensorKit.SDD()) - N = @constinferred leftnull(t, ((3, 4, 2), (1, 5)); alg=alg) - NdN = N' * N - @test NdN ≈ one(NdN) - @test norm(N' * permute(t, ((3, 4, 2), (1, 5)))) < - 100 * eps(norm(t)) + @testset "leftnull with $alg" for alg in (TensorKit.QR(), TensorKit.SVD(), TensorKit.SDD()) + test_leftnull(t, ((3, 4, 2), (1, 5)), alg) end @testset "rightorth with $alg" for alg in (TensorKit.RQ(), TensorKit.RQpos(), From ae8dd7200d7de95522a42f8ae10cdae20549190e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 11:58:52 -0400 Subject: [PATCH 025/126] Change `Base.setdiff` for `ominus` --- docs/src/lib/spaces.md | 1 + src/spaces/cartesianspace.jl | 11 ++++++----- src/spaces/complexspace.jl | 13 +++++++------ src/spaces/gradedspace.jl | 12 ++++++------ src/spaces/vectorspaces.jl | 18 +++++++++++------- src/tensors/matrixalgebrakit.jl | 18 +++++++++--------- 6 files changed, 40 insertions(+), 33 deletions(-) diff --git a/docs/src/lib/spaces.md b/docs/src/lib/spaces.md index 83350156d..205301d83 100644 --- a/docs/src/lib/spaces.md +++ b/docs/src/lib/spaces.md @@ -90,6 +90,7 @@ dual conj flip ⊕ +⊖ zero(::ElementarySpace) oneunit supremum diff --git a/src/spaces/cartesianspace.jl b/src/spaces/cartesianspace.jl index fd38e0c1e..fe12f3dc6 100644 --- a/src/spaces/cartesianspace.jl +++ b/src/spaces/cartesianspace.jl @@ -49,16 +49,17 @@ sectortype(::Type{CartesianSpace}) = Trivial Base.oneunit(::Type{CartesianSpace}) = CartesianSpace(1) Base.zero(::Type{CartesianSpace}) = CartesianSpace(0) + ⊕(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d + V₂.d) +function ⊖(V::CartesianSpace, W::CartesianSpace) + V ≿ W || throw(ArgumentError("$(W) is not a subspace of $(V)")) + return CartesianSpace(dim(V) - dim(W)) +end + fuse(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d * V₂.d) flip(V::CartesianSpace) = V infimum(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(min(V₁.d, V₂.d)) supremum(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(max(V₁.d, V₂.d)) -function Base.setdiff(V::CartesianSpace, W::CartesianSpace) - V ≿ W || throw(ArgumentError("$(W) is not a subspace of $(V)")) - return CartesianSpace(dim(V) - dim(W)) -end - Base.show(io::IO, V::CartesianSpace) = print(io, "ℝ^$(V.d)") diff --git a/src/spaces/complexspace.jl b/src/spaces/complexspace.jl index 51a3056e9..1031db614 100644 --- a/src/spaces/complexspace.jl +++ b/src/spaces/complexspace.jl @@ -50,11 +50,18 @@ Base.conj(V::ComplexSpace) = ComplexSpace(dim(V), !isdual(V)) Base.oneunit(::Type{ComplexSpace}) = ComplexSpace(1) Base.zero(::Type{ComplexSpace}) = ComplexSpace(0) + function ⊕(V₁::ComplexSpace, V₂::ComplexSpace) return isdual(V₁) == isdual(V₂) ? ComplexSpace(dim(V₁) + dim(V₂), isdual(V₁)) : throw(SpaceMismatch("Direct sum of a vector space and its dual does not exist")) end +function ⊖(V::ComplexSpace, W::ComplexSpace) + (V ≿ W && isdual(V) == isdual(W)) || + throw(ArgumentError("$(W) is not a subspace of $(V)")) + return ComplexSpace(dim(V) - dim(W), isdual(V)) +end + fuse(V₁::ComplexSpace, V₂::ComplexSpace) = ComplexSpace(V₁.d * V₂.d) flip(V::ComplexSpace) = dual(V) @@ -69,10 +76,4 @@ function supremum(V₁::ComplexSpace, V₂::ComplexSpace) throw(SpaceMismatch("Supremum of space and dual space does not exist")) end -function Base.setdiff(V::ComplexSpace, W::ComplexSpace) - (V ≿ W && isdual(V) == isdual(W)) || - throw(ArgumentError("$(W) is not a subspace of $(V)")) - return ComplexSpace(dim(V) - dim(W), isdual(V)) -end - Base.show(io::IO, V::ComplexSpace) = print(io, isdual(V) ? "(ℂ^$(V.d))'" : "ℂ^$(V.d)") diff --git a/src/spaces/gradedspace.jl b/src/spaces/gradedspace.jl index e4a016602..ddc08046d 100644 --- a/src/spaces/gradedspace.jl +++ b/src/spaces/gradedspace.jl @@ -149,6 +149,12 @@ function ⊕(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector} return typeof(V₁)(dims; dual=dual1) end +function ⊖(V::GradedSpace{I}, W::GradedSpace{I}) where {I<:Sector} + V ≿ W && isdual(V) == isdual(W) || + throw(SpaceMismatch("$(W) is not a subspace of $(V)")) + return typeof(V)(c => dim(V, c) - dim(W, c) for c in sectors(V)) +end + function flip(V::GradedSpace{I}) where {I<:Sector} if isdual(V) typeof(V)(c => dim(V, c) for c in sectors(V)) @@ -183,12 +189,6 @@ function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector} for c in union(sectors(V₁), sectors(V₂)); dual=Visdual) end -function Base.setdiff(V::GradedSpace{I}, W::GradedSpace{I}) where {I<:Sector} - V ≿ W && isdual(V) == isdual(W) || - throw(SpaceMismatch("$(W) is not a subspace of $(V)")) - return typeof(V)(c => dim(V, c) - dim(W, c) for c in sectors(V)) -end - function Base.show(io::IO, V::GradedSpace{I}) where {I<:Sector} print(io, type_repr(typeof(V)), "(") separator = "" diff --git a/src/spaces/vectorspaces.jl b/src/spaces/vectorspaces.jl index 3b6e6f08e..3b903e19d 100644 --- a/src/spaces/vectorspaces.jl +++ b/src/spaces/vectorspaces.jl @@ -150,6 +150,17 @@ function ⊕ end ⊕(V::Vararg{ElementarySpace}) = foldl(⊕, V) const oplus = ⊕ +""" + ⊖(V::ElementarySpace, W::ElementarySpace) -> X::ElementarySpace + ominus(V::ElementarySpace, W::ElementarySpace) -> X::ElementarySpace + +Return the set difference of two elementary spaces, i.e. an instance `X::ElementarySpace` +such that `V = W ⊕ X`. +""" +⊖(V₁::S, V₂::S) where {S<:ElementarySpace} +⊖(V₁::VectorSpace, V₂::VectorSpace) = ⊖(promote(V₁, V₂)...) +const ominus = ⊖ + """ ⊗(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S @@ -406,10 +417,3 @@ function supremum(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} return supremum(supremum(V₁, V₂), V₃...) end -""" - setdiff(V::ElementarySpace, W::ElementarySpace) - -Return the set difference of two elementary spaces, i.e. an instance `X::ElementarySpace` -such that `V = W ⊕ X`. -""" -Base.setdiff(V₁::S, V₂::S) where {S<:ElementarySpace} diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index f836e51be..f775631c5 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -296,9 +296,9 @@ function MatrixAlgebraKit.check_input(::typeof(qr_null!), t::AbstractTensorMap, # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = setdiff(fuse(codomain(t)), V_Q) + V_N = ⊖(fuse(codomain(t)), V_Q) space(N) == (codomain(t) ← V_N) || - throw(SpaceMismatch("`qr_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) + throw(SpaceMismatch("`qr_null!(t, N)` requires `space(N) == (codomain(t) ← ⊖(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) return nothing end @@ -322,7 +322,7 @@ end function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), t::AbstractTensorMap, ::MatrixAlgebraKit.AbstractAlgorithm) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = setdiff(fuse(codomain(t)), V_Q) + V_N = ⊖(fuse(codomain(t)), V_Q) N = similar(t, codomain(t) ← V_N) return N end @@ -414,9 +414,9 @@ function MatrixAlgebraKit.check_input(::typeof(lq_null!), t::AbstractTensorMap, # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = setdiff(fuse(domain(t)), V_Q) + V_N = ⊖(fuse(domain(t)), V_Q) space(N) == (V_N ← domain(t)) || - throw(SpaceMismatch("`lq_null!(t, N)` requires `space(N) == setdiff(fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`")) + throw(SpaceMismatch("`lq_null!(t, N)` requires `space(N) == ⊖(fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`")) return nothing end @@ -440,7 +440,7 @@ end function MatrixAlgebraKit.initialize_output(::typeof(lq_null!), t::AbstractTensorMap, ::MatrixAlgebraKit.AbstractAlgorithm) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = setdiff(fuse(domain(t)), V_Q) + V_N = ⊖(fuse(domain(t)), V_Q) N = similar(t, V_N ← domain(t)) return N end @@ -634,16 +634,16 @@ function MatrixAlgebraKit.check_input(::typeof(left_null!), t::AbstractTensorMap # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = setdiff(fuse(codomain(t)), V_Q) + V_N = ⊖(fuse(codomain(t)), V_Q) space(N) == (codomain(t) ← V_N) || - throw(SpaceMismatch("`left_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) + throw(SpaceMismatch("`left_null!(t, N)` requires `space(N) == (codomain(t) ← ⊖(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) return nothing end function MatrixAlgebraKit.initialize_output(::typeof(left_null!), t::AbstractTensorMap) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = setdiff(fuse(codomain(t)), V_Q) + V_N = ⊖(fuse(codomain(t)), V_Q) N = similar(t, codomain(t) ← V_N) return N end From ee232e6337e3ba07b9edef1d772c15a2dc4550b6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 10 Jun 2025 21:44:23 -0400 Subject: [PATCH 026/126] change blockscheduler to type domain --- src/tensors/backends.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tensors/backends.jl b/src/tensors/backends.jl index 40cbd1ba5..1fc970e72 100644 --- a/src/tensors/backends.jl +++ b/src/tensors/backends.jl @@ -28,7 +28,8 @@ Run `f` in a scope where the `blockscheduler` is determined by `scheduler' and ` end # TODO: disable for trivial symmetry or small tensors? -default_blockscheduler(t::AbstractTensorMap) = blockscheduler[] +default_blockscheduler(t::AbstractTensorMap) = default_blockscheduler(typeof(t)) +default_blockscheduler(::Type{T}) where {T<:AbstractTensorMap} = blockscheduler[] # MatrixAlgebraKit # ---------------- From b2597547ff5a789ce011ed1d4620eabda9be1128 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 10 Jun 2025 21:44:36 -0400 Subject: [PATCH 027/126] make block iterator loop over union of sectors --- src/tensors/blockiterator.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index 008d9fdf2..984facd1b 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -33,8 +33,9 @@ end ``` """ function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler=nothing) - foreach(blocks(t)) do (c, b) - return f(c, (b, map(Base.Fix2(block, c), ts)...)) + allsectors = union(blocksectors(t), blocksectors.(ts)...) + foreach(allsectors) do c + return f(c, map(Base.Fix2(block, c), (t, ts...))) end return nothing end From cda21d3d3295ea8e880dc72e8f61715dedc08fb7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 10 Jun 2025 21:45:01 -0400 Subject: [PATCH 028/126] refactor `left_orth` for new matrixalgebrakit version --- src/tensors/factorizations.jl | 78 +++++++++++++++++------------------ 1 file changed, 38 insertions(+), 40 deletions(-) diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 97c853632..4c77814b4 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -324,57 +324,55 @@ function _reverse!(t::AbstractTensorMap; dims=:) end function leftorth!(t::TensorMap{<:RealOrComplexFloat}; - alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(), - atol::Real=zero(float(real(scalartype(t)))), - rtol::Real=(alg ∉ (SVD(), SDD())) ? - zero(float(real(scalartype(t)))) : - eps(real(float(one(scalartype(t))))) * - iszero(atol)) + alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar,Nothing}=nothing, + kwargs...) + # atol::Real=zero(float(real(scalartype(t)))), + # rtol::Real=(alg ∉ (SVD(), SDD())) ? + # zero(float(real(scalartype(t)))) : + # eps(real(float(one(scalartype(t))))) * + # iszero(atol)) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:leftorth!) - if alg == SVD() || alg == SDD() - return _leftorth!(t, alg; atol, rtol) - else - (iszero(atol) && iszero(rtol)) || - throw(ArgumentError("`leftorth!` with nonzero atol or rtol requires SVD or SDD algorithm")) - return _leftorth!(t, alg) - end + return _leftorth!(t, alg; kwargs...) + + # if alg == SVD() || alg == SDD() + # return _leftorth!(t, alg; atol, rtol) + # else + # (iszero(atol) && iszero(rtol)) || + # throw(ArgumentError("`leftorth!` with nonzero atol or rtol requires SVD or SDD algorithm")) + # return _leftorth!(t, alg) + # end end # this promotes the algorithm to a positional argument for type stability reasons # since polar has different number of output legs # TODO: this seems like duplication from MatrixAlgebraKit.left_orth!, but that function # only has its logic with the output already specified, which breaks for polar -function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD}; atol::Real, - rtol::Real) - alg′ = alg == SVD() ? MatrixAlgebraKit.LAPACK_QRIteration() : - MatrixAlgebraKit.LAPACK_DivideAndConquer() - alg_svd = BlockAlgorithm(alg′, default_blockscheduler(t)) - if iszero(atol) && iszero(rtol) - U, S, Vᴴ = svd_compact!(t, alg_svd) - return U, lmul!(S, Vᴴ) +function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg; kwargs...) + trunc = isempty(kwargs) ? nothing : (; kwargs...) + if isnothing(alg) + return left_orth!(t; trunc) + elseif alg == SVD() + return left_orth!(t; kind=:svd, alg_svd=:LAPACK_QRIteration, trunc) + elseif alg == SDD() + return left_orth!(t; kind=:svd, alg_svd=:LAPACK_DivideAndConquer, trunc) + elseif alg == QR() + return left_orth!(t; kind=:qr, alg_qr=(; positive=false), trunc) + elseif alg == QRpos() + return left_orth!(t; kind=:qr, alg_qr=(; positive=true), trunc) + elseif alg == QL() || alg == QLpos() + _reverse!(t; dims=2) + Q, R = left_orth!(t; kind=:qr, alg_qr=(; positive=alg == QLpos()), trunc) + _reverse!(Q; dims=2) + _reverse!(R) + return Q, R + elseif alg == Polar() + return left_orth!(t; kind=:polar, trunc) else - trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol) - alg_svd = MatrixAlgebraKit.select_algorithm(svd_trunc!, t; trunc, - alg=alg_svd) - - U, S, Vᴴ = svd_trunc!(t, alg_svd) - return U, lmul!(S, Vᴴ) + throw(ArgumentError(lazy"Invalid algorithm: $alg")) end end -function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{QR,QRpos}) - return qr_compact!(t; positive=alg == QRpos()) -end -function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{QL,QLpos}) - _reverse!(t; dims=2) - Q, R = qr_compact!(t; positive=alg == QLpos()) - _reverse!(Q; dims=2) - _reverse!(R) - return Q, R -end -function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, ::Polar) - return MatrixAlgebraKit.left_polar!(t) -end + function leftnull!(t::TensorMap{<:RealOrComplexFloat}; alg::Union{QR,QRpos,SVD,SDD}=QRpos(), From 89ea628a68b23276c1e5b3d4bdd2467866048945 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 10 Jun 2025 21:51:21 -0400 Subject: [PATCH 029/126] Bunch of simplifications for new matrixalgebrakit versions --- src/tensors/matrixalgebrakit.jl | 260 +++++++++++++------------------- 1 file changed, 101 insertions(+), 159 deletions(-) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl index f775631c5..4481267df 100644 --- a/src/tensors/matrixalgebrakit.jl +++ b/src/tensors/matrixalgebrakit.jl @@ -1,3 +1,11 @@ +# convenience to set default +macro check_space(x, V) + return esc(:($MatrixAlgebraKit.@check_size($x, $V, $space))) +end +macro check_scalar(x, y, op=:identity, eltype=:scalartype) + return esc(:($MatrixAlgebraKit.@check_scalar($x, $y, $op, $eltype))) +end + # Generic # ------- for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, :svd_full, @@ -7,6 +15,31 @@ for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, T = factorisation_scalartype($f, t) return copy_oftype(t, T) end + f! = Symbol(f, :!) + @eval function MatrixAlgebraKit.select_algorithm(::typeof($f!), t::AbstractTensorMap, + alg::Alg=nothing; + kwargs...) where {Alg} + return MatrixAlgebraKit.select_algorithm($f!, typeof(t), alg; kwargs...) + end + @eval function MatrixAlgebraKit.select_algorithm(::typeof($f!), ::Type{T}, + alg::Alg=nothing; + scheduler=default_blockscheduler(T), + kwargs...) where {T<:AbstractTensorMap, + Alg} + mat_alg = MatrixAlgebraKit.select_algorithm($f!, blocktype(T), alg; kwargs...) + return BlockAlgorithm(mat_alg, scheduler) + end +end + +for f in (:qr, :lq, :svd, :eig, :eigh, :polar) + default_f_algorithm = Symbol(:default_, f, :_algorithm) + @eval function MatrixAlgebraKit.$default_f_algorithm(::Type{T}; + scheduler=default_blockscheduler(T), + kwargs...) where {T<:AbstractTensorMap} + return BlockAlgorithm(MatrixAlgebraKit.$default_f_algorithm(blocktype(T); + kwargs...), + scheduler) + end end # TODO: move to MatrixAlgebraKit? @@ -21,14 +54,6 @@ macro check_eltype(x, y, f=:identity, g=:eltype) return esc(:($g($x) == $f($g($y)) || throw(ArgumentError($msg)))) end -function MatrixAlgebraKit._select_algorithm(_, ::AbstractTensorMap, - alg::MatrixAlgebraKit.AbstractAlgorithm) - return alg -end -function MatrixAlgebraKit._select_algorithm(f, t::AbstractTensorMap, alg::NamedTuple) - return MatrixAlgebraKit.select_algorithm(f, t; alg...) -end - function _select_truncation(f, ::AbstractTensorMap, trunc::MatrixAlgebraKit.TruncationStrategy) return trunc @@ -41,11 +66,6 @@ function MatrixAlgebraKit.diagview(t::AbstractTensorMap) return SectorDict(c => MatrixAlgebraKit.diagview(b) for (c, b) in blocks(t)) end -# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) -# T = scalartype(t) -# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) -# end - # Singular value decomposition # ---------------------------- const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} @@ -54,19 +74,16 @@ const _T_USVᴴ_diag = Tuple{<:AbstractTensorMap,<:DiagonalTensorMap,<:AbstractT function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ) # scalartype checks - @check_eltype U t - @check_eltype S t real - @check_eltype Vᴴ t + @check_scalar U t + @check_scalar S t real + @check_scalar Vᴴ t # space checks V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) - space(U) == (codomain(t) ← V_cod) || - throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← fuse(domain(t)))`")) - space(S) == (V_cod ← V_dom) || - throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(S) == (fuse(codomain(t)) ← fuse(domain(t))`")) - space(Vᴴ) == (V_dom ← domain(t)) || - throw(SpaceMismatch("`svd_full!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (fuse(domain(t)) ← domain(t))`")) + @check_space(U, codomain(t) ← V_cod) + @check_space(S, V_cod ← V_dom) + @check_space(Vᴴ, V_dom ← domain(t)) return nothing end @@ -80,12 +97,9 @@ function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorM # space checks V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) - space(U) == (codomain(t) ← V_cod) || - throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires `space(U) == (codomain(t) ← infimum(fuse(domain(t)), fuse(codomain(t)))`")) - space(S) == (V_cod ← V_dom) || - throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires diagonal `S` with `domain(S) == (infimum(fuse(codomain(t)), fuse(domain(t)))`")) - space(Vᴴ) == (V_dom ← domain(t)) || - throw(SpaceMismatch("`svd_compact!(t, (U, S, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(domain(t)), fuse(codomain(t))) ← domain(t))`")) + @check_space(U, codomain(t) ← V_cod) + @check_space(S, V_cod ← V_dom) + @check_space(Vᴴ, V_dom ← domain(t)) return nothing end @@ -124,7 +138,7 @@ function MatrixAlgebraKit.svd_full!(t::AbstractTensorMap, (U, S, Vᴴ), foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ) if isempty(b) # TODO: remove once MatrixAlgebraKit supports empty matrices - one!(length(u) > 0 ? u : vᴴ) + MatrixAlgebraKit.one!(length(u) > 0 ? u : vᴴ) zerovector!(s) else u′, s′, vᴴ′ = MatrixAlgebraKit.svd_full!(b, (u, s, vᴴ), alg.alg) @@ -161,12 +175,6 @@ function MatrixAlgebraKit.svd_trunc!(t::AbstractTensorMap, USVᴴ, return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ′, alg.trunc) end -function MatrixAlgebraKit.default_svd_algorithm(t::AbstractTensorMap{<:BlasFloat}; - scheduler=default_blockscheduler(t), - kwargs...) - return BlockAlgorithm(LAPACK_DivideAndConquer(; kwargs...), scheduler) -end - # Eigenvalue decomposition # ------------------------ const _T_DV = Tuple{<:DiagonalTensorMap,<:AbstractTensorMap} @@ -176,15 +184,13 @@ function MatrixAlgebraKit.check_input(::typeof(eigh_full!), t::AbstractTensorMap throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) # scalartype checks - @check_eltype D t real - @check_eltype V t + @check_scalar D t real + @check_scalar V t # space checks V_D = fuse(domain(t)) - V_D == space(D, 1) || - throw(SpaceMismatch("`eigh_full!(t, (D, V))` requires diagonal `D` with `domain(D) == fuse(domain(t))`")) - space(V) == (codomain(t) ← V_D) || - throw(SpaceMismatch("`eigh_full!(t, (D, V))` requires `space(V) == (codomain(t) ← fuse(domain(t)))`")) + @check_space(D, V_D ← V_D) + @check_space(V, codomain(t) ← V_D) return nothing end @@ -195,15 +201,13 @@ function MatrixAlgebraKit.check_input(::typeof(eig_full!), t::AbstractTensorMap, throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) # scalartype checks - @check_eltype D t complex - @check_eltype V t complex + @check_scalar D t complex + @check_scalar V t complex # space checks V_D = fuse(domain(t)) - V_D == space(D, 1) || - throw(SpaceMismatch("`eig_full!(t, (D, V))` requires diagonal `D` with `domain(D) == fuse(domain(t))`")) - space(V) == (codomain(t) ← V_D) || - throw(SpaceMismatch("`eig_full!(t, (D, V))` requires `space(V) == (codomain(t) ← fuse(domain(t)))`")) + @check_space(D, V_D ← V_D) + @check_space(V, codomain(t) ← V_D) return nothing end @@ -243,48 +247,32 @@ for f in (:eigh_full!, :eig_full!) end end -function MatrixAlgebraKit.default_eig_algorithm(t::AbstractTensorMap{<:BlasFloat}; - scheduler=default_blockscheduler(t), - kwargs...) - return BlockAlgorithm(LAPACK_Expert(; kwargs...), scheduler) -end -function MatrixAlgebraKit.default_eigh_algorithm(t::AbstractTensorMap{<:BlasFloat}; - scheduler=default_blockscheduler(t), - kwargs...) - return BlockAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...), - scheduler) -end - # QR decomposition # ---------------- function MatrixAlgebraKit.check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::Tuple{<:AbstractTensorMap,<:AbstractTensorMap}) # scalartype checks - @check_eltype Q t - @check_eltype R t + @check_scalar Q t + @check_scalar R t # space checks V_Q = fuse(codomain(t)) - space(Q) == (codomain(t) ← V_Q) || - throw(SpaceMismatch("`qr_full!(t, (Q, R))` requires `space(Q) == (codomain(t) ← fuse(codomain(t)))`")) - space(R) == (V_Q ← domain(t)) || - throw(SpaceMismatch("`qr_full!(t, (Q, R))` requires `space(R) == (fuse(codomain(t)) ← domain(t)`")) + @check_space(Q, codomain(t) ← V_Q) + @check_space(R, V_Q ← domain(t)) return nothing end function MatrixAlgebraKit.check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)) # scalartype checks - @check_eltype Q t - @check_eltype R t + @check_scalar Q t + @check_scalar R t # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - space(Q) == (codomain(t) ← V_Q) || - throw(SpaceMismatch("`qr_compact!(t, (Q, R))` requires `space(Q) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`")) - space(R) == (V_Q ← domain(t)) || - throw(SpaceMismatch("`qr_compact!(t, (Q, R))` requires `space(R) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`")) + @check_space(Q, codomain(t) ← V_Q) + @check_space(R, V_Q ← domain(t)) return nothing end @@ -292,13 +280,12 @@ end function MatrixAlgebraKit.check_input(::typeof(qr_null!), t::AbstractTensorMap, N::AbstractTensorMap) # scalartype checks - @check_eltype N t + @check_scalar N t # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) V_N = ⊖(fuse(codomain(t)), V_Q) - space(N) == (codomain(t) ← V_N) || - throw(SpaceMismatch("`qr_null!(t, N)` requires `space(N) == (codomain(t) ← ⊖(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) + @check_space(N, codomain(t) ← V_N) return nothing end @@ -370,11 +357,6 @@ function MatrixAlgebraKit.qr_null!(t::AbstractTensorMap, N, alg::BlockAlgorithm) return N end -function MatrixAlgebraKit.default_qr_algorithm(t::AbstractTensorMap{<:BlasFloat}; - scheduler=default_blockscheduler(t), - kwargs...) - return BlockAlgorithm(LAPACK_HouseholderQR(; kwargs...), scheduler) -end # LQ decomposition # ---------------- @@ -395,28 +377,25 @@ end function MatrixAlgebraKit.check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)) # scalartype checks - @check_eltype L t - @check_eltype Q t + @check_scalar L t + @check_scalar Q t # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - space(L) == (codomain(t) ← V_Q) || - throw(SpaceMismatch("`lq_compact!(t, (L, Q))` requires `space(L) == infimum(fuse(codomain(t)), fuse(domain(t)))`")) - space(Q) == (V_Q ← domain(t)) || - throw(SpaceMismatch("`lq_compact!(t, (L, Q))` requires `space(Q) == infimum(fuse(codomain(t)), fuse(domain(t)))`")) + @check_space(L, codomain(t) ← V_Q) + @check_space(Q, V_Q ← domain(t)) return nothing end function MatrixAlgebraKit.check_input(::typeof(lq_null!), t::AbstractTensorMap, N) # scalartype checks - @check_eltype N t + @check_scalar N t # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) V_N = ⊖(fuse(domain(t)), V_Q) - space(N) == (V_N ← domain(t)) || - throw(SpaceMismatch("`lq_null!(t, N)` requires `space(N) == ⊖(fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`")) + @check_space(N, V_N ← domain(t)) return nothing end @@ -497,21 +476,19 @@ function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P)) throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) # scalartype checks - @check_eltype W t - @check_eltype P t + @check_scalar W t + @check_scalar P t # space checks - space(W) == space(t) || - throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(W) == (codomain(t) ← domain(t))`")) - space(P) == (domain(t) ← domain(t)) || - throw(SpaceMismatch("`left_polar!(t, (W, P))` requires `space(P) == (domain(t) ← domain(t))`")) + @check_space(W, space(t)) + @check_space(P, domain(t) ← domain(t)) return nothing end # TODO: do we really not want to fuse the spaces? function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) + ::BlockAlgorithm) W = similar(t, space(t)) P = similar(t, domain(t) ← domain(t)) return W, P @@ -531,41 +508,46 @@ function MatrixAlgebraKit.left_polar!(t::AbstractTensorMap, WP, alg::BlockAlgori return WP end -function MatrixAlgebraKit.default_polar_algorithm(t::AbstractTensorMap{<:BlasFloat}; - scheduler=default_blockscheduler(t), - kwargs...) - return BlockAlgorithm(PolarViaSVD(LAPACK_DivideAndConquer(; kwargs...)), - scheduler) +# Trick to relax the checks of "square" if coming from left_orth +function MatrixAlgebraKit.left_orth_polar!(t::AbstractTensorMap, VC, alg) + alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg) + return MatrixAlgebraKit.left_orth_polar!(t, VC, alg′) +end +function MatrixAlgebraKit.left_orth_polar!(t::AbstractTensorMap, WP, alg::BlockAlgorithm) + foreachblock(t, WP...; alg.scheduler) do _, (b, w, p) + w′, p′ = left_polar!(b, (w, p), alg.alg) + # deal with the case where the output is not the same as the input + w === w′ || copyto!(w, w′) + p === p′ || copyto!(p, p′) + return nothing + end + return WP end # Orthogonalization # ----------------- function MatrixAlgebraKit.check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)) # scalartype checks - @check_eltype V t - isnothing(C) || @check_eltype C t + @check_scalar V t + isnothing(C) || @check_scalar C t # space checks V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - space(V) == (codomain(t) ← V_C) || - throw(SpaceMismatch("`left_orth!(t, (V, C))` requires `space(V) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t))))`")) - isnothing(C) || space(C) == (V_C ← domain(t)) || - throw(SpaceMismatch("`left_orth!(t, (V, C))` requires `space(C) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`")) + @check_space(V, codomain(t) ← V_C) + isnothing(C) || @check_space(CV_C ← domain(t)) return nothing end function MatrixAlgebraKit.check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)) # scalartype checks - isnothing(C) || @check_eltype C t - @check_eltype Vᴴ t + isnothing(C) || @check_scalar C t + @check_scalar Vᴴ t # space checks V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - isnothing(C) || space(C) == (codomain(t) ← V_C) || - throw(SpaceMismatch("`right_orth!(t, (C, Vᴴ))` requires `space(C) == (codomain(t) ← infimum(fuse(codomain(t)), fuse(domain(t)))`")) - space(Vᴴ) == (V_dom ← domain(t)) || - throw(SpaceMismatch("`right_orth!(t, (C, Vᴴ))` requires `space(Vᴴ) == (infimum(fuse(codomain(t)), fuse(domain(t))) ← domain(t))`")) + isnothing(C) || @check_space(C, codomain(t) ← V_C) + @check_space(Vᴴ, V_dom ← domain(t)) return nothing end @@ -584,59 +566,16 @@ function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTe return C, Vᴴ end -function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC; - trunc=nothing, - kind=isnothing(trunc) ? - :qr : :svd, - alg_qr=(; positive=true), - alg_polar=(;), - alg_svd=(;)) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) - end - - if kind == :qr - alg_qr′ = MatrixAlgebraKit._select_algorithm(qr_compact!, t, alg_qr) - return qr_compact!(t, VC, alg_qr′) - end - - if kind == :polar - alg_polar′ = MatrixAlgebraKit._select_algorithm(left_polar!, t, alg_polar) - return left_polar!(t, VC, alg_polar′) - end - - if kind == :svd && isnothing(trunc) - alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_compact!, t, alg_svd) - V, C = VC - S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C)) - U, S, Vᴴ = svd_compact!(t, (V, S, C), alg_svd′) - return U, lmul!(S, Vᴴ) - end - - if kind == :svd - alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_compact!, t, alg_svd) - alg_svd_trunc = MatrixAlgebraKit.select_algorithm(svd_trunc!, t; trunc, - alg=alg_svd′) - V, C = VC - S = DiagonalTensorMap{real(scalartype(t))}(undef, domain(V) ← codomain(C)) - U, S, Vᴴ = svd_trunc!(t, (V, S, C), alg_svd_trunc) - return U, lmul!(S, Vᴴ) - end - - throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) -end - # Nullspace # --------- function MatrixAlgebraKit.check_input(::typeof(left_null!), t::AbstractTensorMap, N) # scalartype checks - @check_eltype N t + @check_scalar N t # space checks V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) V_N = ⊖(fuse(codomain(t)), V_Q) - space(N) == (codomain(t) ← V_N) || - throw(SpaceMismatch("`left_null!(t, N)` requires `space(N) == (codomain(t) ← ⊖(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`")) + @check_space(N, codomain(t) ← V_N) return nothing end @@ -670,9 +609,11 @@ function MatrixAlgebraKit.left_null!(t::AbstractTensorMap, N; end if kind == :qr + @info "qr" alg_qr′ = MatrixAlgebraKit._select_algorithm(qr_null!, t, alg_qr) return qr_null!(t, N, alg_qr′) elseif kind == :svd && isnothing(trunc) + @info "svd" alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd) # TODO: refactor into separate function U, _, _ = svd_full!(t, alg_svd′) @@ -683,9 +624,10 @@ function MatrixAlgebraKit.left_null!(t::AbstractTensorMap, N; end return N elseif kind == :svd + @info "svd2" alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd) - U, S, _ = svd_full!(t, alg_svd′) - trunc′ = _select_truncation(left_null!, t, trunc) + @show U, S, _ = svd_full!(t, alg_svd′) + @show trunc′ = _select_truncation(left_null!, t, @show trunc) return MatrixAlgebraKit.truncate!(left_null!, (U, S), trunc′) else throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) From 360a99110154472ef0bd8e9ef3ad6c8d2261e34b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 11 Jun 2025 21:02:54 -0400 Subject: [PATCH 030/126] Major overhaul --- .../factorizations.jl | 256 +++---- src/TensorKit.jl | 16 +- src/auxiliary/deprecate.jl | 46 +- src/spaces/vectorspaces.jl | 1 - src/tensors/backends.jl | 12 - src/tensors/diagonal.jl | 59 -- src/tensors/factorizations.jl | 707 ----------------- src/tensors/factorizations/deprecations.jl | 1 + src/tensors/factorizations/factorizations.jl | 215 ++++++ src/tensors/factorizations/implementations.jl | 167 ++++ src/tensors/factorizations/interface.jl | 242 ++++++ .../factorizations/matrixalgebrakit.jl | 508 +++++++++++++ src/tensors/factorizations/truncation.jl | 270 +++++++ src/tensors/factorizations/utility.jl | 29 + src/tensors/matrixalgebrakit.jl | 715 ------------------ src/tensors/truncation.jl | 163 ---- test/factorizations.jl | 554 +++++++------- test/tensors.jl | 49 +- 18 files changed, 1912 insertions(+), 2098 deletions(-) delete mode 100644 src/tensors/factorizations.jl create mode 100644 src/tensors/factorizations/deprecations.jl create mode 100644 src/tensors/factorizations/factorizations.jl create mode 100644 src/tensors/factorizations/implementations.jl create mode 100644 src/tensors/factorizations/interface.jl create mode 100644 src/tensors/factorizations/matrixalgebrakit.jl create mode 100644 src/tensors/factorizations/truncation.jl create mode 100644 src/tensors/factorizations/utility.jl delete mode 100644 src/tensors/matrixalgebrakit.jl delete mode 100644 src/tensors/truncation.jl diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index abfb724bc..a91c1dbc0 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,16 +1,17 @@ +using MatrixAlgebraKit: svd_compact_pullback! + # Factorizations rules # -------------------- function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; - trunc::TensorKit.TruncationScheme=TensorKit.NoTruncation(), - p::Real=2, + trunc::TensorKit.TruncationScheme=TensorKit.notrunc(), alg::Union{TensorKit.SVD,TensorKit.SDD}=TensorKit.SDD()) - U, Σ, V⁺, truncerr = tsvd(t; trunc=TensorKit.NoTruncation(), p=p, alg=alg) + U, Σ, V⁺, truncerr = tsvd(t; trunc=TensorKit.notrunc(), alg) - if !(trunc isa TensorKit.NoTruncation) && !isempty(blocksectors(t)) + if !(trunc == TensorKit.notrunc()) && !isempty(blocksectors(t)) Σdata = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(Σ)) - truncdim = TensorKit._compute_truncdim(Σdata, trunc, p) - truncerr = TensorKit._compute_truncerr(Σdata, truncdim, p) + truncdim = TensorKit._compute_truncdim(Σdata, trunc; p=2) + truncerr = TensorKit._compute_truncerr(Σdata, truncdim; p=2) SVDdata = TensorKit.SectorDict(c => (block(U, c), Σc, block(V⁺, c)) for (c, Σc) in Σdata) @@ -23,12 +24,11 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; function tsvd!_pullback(ΔUSVϵ) ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ) Δt = similar(t) - for (c, b) in blocks(Δt) - Uc, Σc, V⁺c = block(U, c), block(Σ, c), block(V⁺, c) - ΔUc, ΔΣc, ΔV⁺c = block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c) - Σdc = view(Σc, diagind(Σc)) - ΔΣdc = (ΔΣc isa AbstractZero) ? ΔΣc : view(ΔΣc, diagind(ΔΣc)) - svd_pullback!(b, Uc, Σdc, V⁺c, ΔUc, ΔΣdc, ΔV⁺c) + foreachblock(Δt) do (c, b) + USVᴴc = block(U, c), block(Σ, c), block(V⁺, c) + ΔUSVᴴc = block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c) + svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc) + return nothing end return NoTangent(), Δt end @@ -187,122 +187,122 @@ end # Other implementation considerations for GPU compatibility: # no scalar indexing, lots of broadcasting and views # -function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector, - Vd::AbstractMatrix, ΔU, ΔS, ΔVd; - tol::Real=default_pullback_gaugetol(S)) - - # Basic size checks and determination - m, n = size(U, 1), size(Vd, 2) - size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch()) - p = -1 - if !(ΔU isa AbstractZero) - m == size(ΔU, 1) || throw(DimensionMismatch()) - p = size(ΔU, 2) - end - if !(ΔVd isa AbstractZero) - n == size(ΔVd, 2) || throw(DimensionMismatch()) - if p == -1 - p = size(ΔVd, 1) - else - p == size(ΔVd, 1) || throw(DimensionMismatch()) - end - end - if !(ΔS isa AbstractZero) - if p == -1 - p = length(ΔS) - else - p == length(ΔS) || throw(DimensionMismatch()) - end - end - Up = view(U, :, 1:p) - Vp = view(Vd, 1:p, :)' - Sp = view(S, 1:p) - - # rank - r = searchsortedlast(S, tol; rev=true) - - # compute antihermitian part of projection of ΔU and ΔV onto U and V - # also already subtract this projection from ΔU and ΔV - if !(ΔU isa AbstractZero) - UΔU = Up' * ΔU - aUΔU = rmul!(UΔU - UΔU', 1 / 2) - if m > p - ΔU -= Up * UΔU - end - else - aUΔU = fill!(similar(U, (p, p)), 0) - end - if !(ΔVd isa AbstractZero) - VΔV = Vp' * ΔVd' - aVΔV = rmul!(VΔV - VΔV', 1 / 2) - if n > p - ΔVd -= VΔV' * Vp' - end - else - aVΔV = fill!(similar(Vd, (p, p)), 0) - end - - # check whether cotangents arise from gauge-invariance objective function - mask = abs.(Sp' .- Sp) .< tol - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) - if p > r - rprange = (r + 1):p - Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf)) - Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf)) - end - Δgauge < tol || - @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - - UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+ - (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol) - if !(ΔS isa ZeroTangent) - UdΔAV[diagind(UdΔAV)] .+= real.(ΔS) - # in principle, ΔS is real, but maybe not if coming from an anyonic tensor - end - mul!(ΔA, Up, UdΔAV * Vp') - - if r > p # contribution from truncation - Ur = view(U, :, (p + 1):r) - Vr = view(Vd, (p + 1):r, :)' - Sr = view(S, (p + 1):r) - - if !(ΔU isa AbstractZero) - UrΔU = Ur' * ΔU - if m > r - ΔU -= Ur * UrΔU # subtract this part from ΔU - end - else - UrΔU = fill!(similar(U, (r - p, p)), 0) - end - if !(ΔVd isa AbstractZero) - VrΔV = Vr' * ΔVd' - if n > r - ΔVd -= VrΔV' * Vr' # subtract this part from ΔV - end - else - VrΔV = fill!(similar(Vd, (r - p, p)), 0) - end - - X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+ - (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol)) - Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .- - (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol)) - - # ΔA += Ur * X * Vp' + Up * Y' * Vr' - mul!(ΔA, Ur, X * Vp', 1, 1) - mul!(ΔA, Up * Y', Vr', 1, 1) - end - - if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)] - # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp' - mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1) - end - if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)] - # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd) - mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1) - end - return ΔA -end +# function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector, +# Vd::AbstractMatrix, ΔU, ΔS, ΔVd; +# tol::Real=default_pullback_gaugetol(S)) + +# # Basic size checks and determination +# m, n = size(U, 1), size(Vd, 2) +# size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch()) +# p = -1 +# if !(ΔU isa AbstractZero) +# m == size(ΔU, 1) || throw(DimensionMismatch()) +# p = size(ΔU, 2) +# end +# if !(ΔVd isa AbstractZero) +# n == size(ΔVd, 2) || throw(DimensionMismatch()) +# if p == -1 +# p = size(ΔVd, 1) +# else +# p == size(ΔVd, 1) || throw(DimensionMismatch()) +# end +# end +# if !(ΔS isa AbstractZero) +# if p == -1 +# p = length(ΔS) +# else +# p == length(ΔS) || throw(DimensionMismatch()) +# end +# end +# Up = view(U, :, 1:p) +# Vp = view(Vd, 1:p, :)' +# Sp = view(S, 1:p) + +# # rank +# r = searchsortedlast(S, tol; rev=true) + +# # compute antihermitian part of projection of ΔU and ΔV onto U and V +# # also already subtract this projection from ΔU and ΔV +# if !(ΔU isa AbstractZero) +# UΔU = Up' * ΔU +# aUΔU = rmul!(UΔU - UΔU', 1 / 2) +# if m > p +# ΔU -= Up * UΔU +# end +# else +# aUΔU = fill!(similar(U, (p, p)), 0) +# end +# if !(ΔVd isa AbstractZero) +# VΔV = Vp' * ΔVd' +# aVΔV = rmul!(VΔV - VΔV', 1 / 2) +# if n > p +# ΔVd -= VΔV' * Vp' +# end +# else +# aVΔV = fill!(similar(Vd, (p, p)), 0) +# end + +# # check whether cotangents arise from gauge-invariance objective function +# mask = abs.(Sp' .- Sp) .< tol +# Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) +# if p > r +# rprange = (r + 1):p +# Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf)) +# Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf)) +# end +# Δgauge < tol || +# @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + +# UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+ +# (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol) +# if !(ΔS isa ZeroTangent) +# UdΔAV[diagind(UdΔAV)] .+= real.(ΔS) +# # in principle, ΔS is real, but maybe not if coming from an anyonic tensor +# end +# mul!(ΔA, Up, UdΔAV * Vp') + +# if r > p # contribution from truncation +# Ur = view(U, :, (p + 1):r) +# Vr = view(Vd, (p + 1):r, :)' +# Sr = view(S, (p + 1):r) + +# if !(ΔU isa AbstractZero) +# UrΔU = Ur' * ΔU +# if m > r +# ΔU -= Ur * UrΔU # subtract this part from ΔU +# end +# else +# UrΔU = fill!(similar(U, (r - p, p)), 0) +# end +# if !(ΔVd isa AbstractZero) +# VrΔV = Vr' * ΔVd' +# if n > r +# ΔVd -= VrΔV' * Vr' # subtract this part from ΔV +# end +# else +# VrΔV = fill!(similar(Vd, (r - p, p)), 0) +# end + +# X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+ +# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol)) +# Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .- +# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol)) + +# # ΔA += Ur * X * Vp' + Up * Y' * Vr' +# mul!(ΔA, Ur, X * Vp', 1, 1) +# mul!(ΔA, Up * Y', Vr', 1, 1) +# end + +# if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)] +# # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp' +# mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1) +# end +# if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)] +# # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd) +# mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1) +# end +# return ΔA +# end function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; tol::Real=default_pullback_gaugetol(D)) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 0e7a1cede..7eb3652ec 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -31,7 +31,7 @@ export TruncationScheme export SpaceMismatch, SectorMismatch, IndexError # error types # general vector space methods -export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, oplus, +export space, field, dual, dim, reduceddim, dims, fuse, flip, isdual, oplus, ominus, insertleftunit, insertrightunit, removeunit # partial order for vector spaces @@ -47,7 +47,7 @@ export ZNSpace, SU2Irrep, U1Irrep, CU1Irrep # bendleft, bendright, foldleft, foldright, cycleclockwise, cycleanticlockwise # some unicode -export ⊕, ⊗, ×, ⊠, ℂ, ℝ, ℤ, ←, →, ≾, ≿, ≅, ≺, ≻ +export ⊕, ⊗, ⊖, ×, ⊠, ℂ, ℝ, ℤ, ←, →, ≾, ≿, ≅, ≺, ≻ export ℤ₂, ℤ₃, ℤ₄, U₁, SU, SU₂, CU₁ export fℤ₂, fU₁, fSU₂ export ℤ₂Space, ℤ₃Space, ℤ₄Space, U₁Space, CU₁Space, SU₂Space @@ -70,8 +70,8 @@ export inner, dot, norm, normalize, normalize!, tr # factorizations export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby! -export leftorth, rightorth, leftnull, rightnull, - leftorth!, rightorth!, leftnull!, rightnull!, +export leftorth, rightorth, leftnull, rightnull, leftpolar, rightpolar, + leftorth!, rightorth!, leftnull!, rightnull!, leftpolar!, rightpolar!, tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, isposdef, isposdef!, ishermitian, isisometry, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, @@ -217,11 +217,13 @@ include("tensors/tensoroperations.jl") include("tensors/treetransformers.jl") include("tensors/indexmanipulations.jl") include("tensors/diagonal.jl") -include("tensors/truncation.jl") -include("tensors/matrixalgebrakit.jl") -include("tensors/factorizations.jl") include("tensors/braidingtensor.jl") +include("tensors/factorizations/factorizations.jl") +using .Factorizations +# include("tensors/factorizations/matrixalgebrakit.jl") +# include("tensors/truncation.jl") + # # Planar macros and related functionality # #----------------------------------------- @nospecialize diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index fa7667b2b..b235cbd7c 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -1,29 +1,29 @@ import Base: transpose #! format: off -Base.@deprecate(permute(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false), - permute(t, (p1, p2); copy=copy)) -Base.@deprecate(transpose(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false), - transpose(t, (p1, p2); copy=copy)) -Base.@deprecate(braid(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple, levels; copy::Bool=false), - braid(t, (p1, p2), levels; copy=copy)) - -Base.@deprecate(tsvd(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - tsvd(t, (p₁, p₂); kwargs...)) -Base.@deprecate(leftorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - leftorth(t, (p₁, p₂); kwargs...)) -Base.@deprecate(rightorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - rightorth(t, (p₁, p₂); kwargs...)) -Base.@deprecate(leftnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - leftnull(t, (p₁, p₂); kwargs...)) -Base.@deprecate(rightnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - rightnull(t, (p₁, p₂); kwargs...)) -Base.@deprecate(LinearAlgebra.eigen(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - LinearAlgebra.eigen(t, (p₁, p₂); kwargs...), false) -Base.@deprecate(eig(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - eig(t, (p₁, p₂); kwargs...)) -Base.@deprecate(eigh(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), - eigh(t, (p₁, p₂); kwargs...)) +# Base.@deprecate(permute(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false), +# permute(t, (p1, p2); copy=copy)) +# Base.@deprecate(transpose(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false), +# transpose(t, (p1, p2); copy=copy)) +# Base.@deprecate(braid(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple, levels; copy::Bool=false), +# braid(t, (p1, p2), levels; copy=copy)) + +# Base.@deprecate(tsvd(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# tsvd(t, (p₁, p₂); kwargs...)) +# Base.@deprecate(leftorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# leftorth(t, (p₁, p₂); kwargs...)) +# Base.@deprecate(rightorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# rightorth(t, (p₁, p₂); kwargs...)) +# Base.@deprecate(leftnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# leftnull(t, (p₁, p₂); kwargs...)) +# Base.@deprecate(rightnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# rightnull(t, (p₁, p₂); kwargs...)) +# Base.@deprecate(LinearAlgebra.eigen(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# LinearAlgebra.eigen(t, (p₁, p₂); kwargs...), false) +# Base.@deprecate(eig(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# eig(t, (p₁, p₂); kwargs...)) +# Base.@deprecate(eigh(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), +# eigh(t, (p₁, p₂); kwargs...)) for f in (:rand, :randn, :zeros, :ones) @eval begin diff --git a/src/spaces/vectorspaces.jl b/src/spaces/vectorspaces.jl index 3b903e19d..92fd780ef 100644 --- a/src/spaces/vectorspaces.jl +++ b/src/spaces/vectorspaces.jl @@ -416,4 +416,3 @@ have the same value. function supremum(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} return supremum(supremum(V₁, V₂), V₃...) end - diff --git a/src/tensors/backends.jl b/src/tensors/backends.jl index 1fc970e72..1083115b9 100644 --- a/src/tensors/backends.jl +++ b/src/tensors/backends.jl @@ -30,15 +30,3 @@ end # TODO: disable for trivial symmetry or small tensors? default_blockscheduler(t::AbstractTensorMap) = default_blockscheduler(typeof(t)) default_blockscheduler(::Type{T}) where {T<:AbstractTensorMap} = blockscheduler[] - -# MatrixAlgebraKit -# ---------------- -""" - BlockAlgorithm{A,S}(alg, scheduler) - -Generic wrapper for implementing block-wise algorithms. -""" -struct BlockAlgorithm{A,S} <: MatrixAlgebraKit.AbstractAlgorithm - alg::A - scheduler::S -end diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 88d0c3b25..5a3840f1b 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -317,65 +317,6 @@ function LinearAlgebra.isposdef(d::DiagonalTensorMap) return all(isposdef, d.data) end -function eig!(d::DiagonalTensorMap) - return d, one(d) -end -function eigh!(d::DiagonalTensorMap{<:Real}) - return d, one(d) -end -function eigh!(d::DiagonalTensorMap{<:Complex}) - # TODO: should this test for hermiticity? `eigh!(::TensorMap)` also does not do this. - return DiagonalTensorMap(real(d.data), d.domain), one(d) -end - -function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...) - @assert alg isa Union{QR,QL} - return one(d), d # TODO: this is only correct for `alg = QR()` or `alg = QL()` -end -function rightorth!(d::DiagonalTensorMap; alg=LQ(), kwargs...) - @assert alg isa Union{LQ,RQ} - return d, one(d) # TODO: this is only correct for `alg = LQ()` or `alg = RQ()` -end -# not much to do here: -leftnull!(d::DiagonalTensorMap; kwargs...) = leftnull!(TensorMap(d); kwargs...) -rightnull!(d::DiagonalTensorMap; kwargs...) = rightnull!(TensorMap(d); kwargs...) - -function tsvd!(d::DiagonalTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) - return _tsvd!(d, alg, trunc, p) -end -# helper function -function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD}) - InnerProductStyle(d) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) - I = sectortype(d) - dims = SectorDict{I,Int}() - generator = Base.Iterators.map(blocks(d)) do (c, b) - lb = length(b.diag) - U = zerovector!(similar(b.diag, lb, lb)) - V = zerovector!(similar(b.diag, lb, lb)) - p = sortperm(b.diag; by=abs, rev=true) - for (i, pi) in enumerate(p) - U[pi, i] = MatrixAlgebra.safesign(b.diag[pi]) - V[i, pi] = 1 - end - Σ = abs.(view(b.diag, p)) - dims[c] = lb - return c => (U, Σ, V) - end - SVDdata = SectorDict(generator) - return SVDdata, dims -end - -function LinearAlgebra.svdvals(d::DiagonalTensorMap) - return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d)) -end -function LinearAlgebra.eigvals(d::DiagonalTensorMap) - return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d)) -end - -function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2) - return LinearAlgebra.cond(Diagonal(d.data), p) -end - # matrix functions for f in (:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh, :sqrt, diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl deleted file mode 100644 index 4c77814b4..000000000 --- a/src/tensors/factorizations.jl +++ /dev/null @@ -1,707 +0,0 @@ -# Tensor factorization -#---------------------- -function factorisation_scalartype(t::AbstractTensorMap) - T = scalartype(t) - return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) -end -factorisation_scalartype(f, t) = factorisation_scalartype(t) - -function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple) - return permute!(similar(t, T, permute(space(t), p)), t, p) -end -function copy_oftype(t::AbstractTensorMap, T::Type{<:Number}) - return copy!(similar(t, T), t) -end - -""" - tsvd(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple; - trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD()) - -> U, S, V, ϵ - -Compute the (possibly truncated) singular value decomposition such that -`norm(permute(t, (leftind, rightind)) - U * S * V) ≈ ϵ`, where `ϵ` thus represents the truncation error. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `tsvd!(t, trunc = notrunc(), p = 2)`. - -A truncation parameter `trunc` can be specified for the new internal dimension, in which -case a truncated singular value decomposition will be computed. Choices are: -* `notrunc()`: no truncation (default); -* `truncerr(η::Real)`: truncates such that the p-norm of the truncated singular values is - smaller than `η`; -* `truncdim(χ::Int)`: truncates such that the equivalent total dimension of the internal - vector space is no larger than `χ`; -* `truncspace(V)`: truncates such that the dimension of the internal vector space is no - greater than that of `V` in any sector. -* `truncbelow(η::Real)`: truncates such that every singular value is larger then `η` ; - -Truncation options can also be combined using `&`, i.e. `truncbelow(η) & truncdim(χ)` will -choose the truncation space such that every singular value is larger than `η`, and the -equivalent total dimension of the internal vector space is no larger than `χ`. - -The method `tsvd` also returns the truncation error `ϵ`, computed as the `p` norm of the -singular values that were truncated. - -THe keyword `alg` can be equal to `SVD()` or `SDD()`, corresponding to the underlying LAPACK -algorithm that computes the decomposition (`_gesvd` or `_gesdd`). - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)` -is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`. -""" -function tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p) - return tsvd!(tcopy; kwargs...) -end - -function LinearAlgebra.svdvals(t::AbstractTensorMap) - tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t)) - return LinearAlgebra.svdvals!(tcopy) -end - -""" - leftorth(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple; - alg::OrthogonalFactorizationAlgorithm = QRpos()) -> Q, R - -Create orthonormal basis `Q` for indices in `leftind`, and remainder `R` such that -`permute(t, (leftind, rightind)) = Q*R`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `leftorth!(t, alg = QRpos())`. - -Different algorithms are available, namely `QR()`, `QRpos()`, `SVD()` and `Polar()`. `QR()` -and `QRpos()` use a standard QR decomposition, producing an upper triangular matrix `R`. -`Polar()` produces a Hermitian and positive semidefinite `R`. `QRpos()` corrects the -standard QR decomposition such that the diagonal elements of `R` are positive. Only -`QRpos()` and `Polar()` are unique (no residual freedom) so that they always return the same -result for the same input tensor `t`. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`leftorth(!)` is currently only implemented for - `InnerProductStyle(t) === EuclideanInnerProduct()`. -""" -function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p) - return leftorth!(tcopy; kwargs...) -end - -""" - rightorth(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple; - alg::OrthogonalFactorizationAlgorithm = LQpos()) -> L, Q - -Create orthonormal basis `Q` for indices in `rightind`, and remainder `L` such that -`permute(t, (leftind, rightind)) = L*Q`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `rightorth!(t, alg = LQpos())`. - -Different algorithms are available, namely `LQ()`, `LQpos()`, `RQ()`, `RQpos()`, `SVD()` and -`Polar()`. `LQ()` and `LQpos()` produce a lower triangular matrix `L` and are computed using -a QR decomposition of the transpose. `RQ()` and `RQpos()` produce an upper triangular -remainder `L` and only works if the total left dimension is smaller than or equal to the -total right dimension. `LQpos()` and `RQpos()` add an additional correction such that the -diagonal elements of `L` are positive. `Polar()` produces a Hermitian and positive -semidefinite `L`. Only `LQpos()`, `RQpos()` and `Polar()` are unique (no residual freedom) -so that they always return the same result for the same input tensor `t`. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`rightorth(!)` is currently only implemented for -`InnerProductStyle(t) === EuclideanInnerProduct()`. -""" -function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p) - return rightorth!(tcopy; kwargs...) -end - -""" - leftnull(t::AbstractTensor, (leftind, rightind)::Index2Tuple; - alg::OrthogonalFactorizationAlgorithm = QRpos()) -> N - -Create orthonormal basis for the orthogonal complement of the support of the indices in -`leftind`, such that `N' * permute(t, (leftind, rightind)) = 0`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `leftnull!(t, alg = QRpos())`. - -Different algorithms are available, namely `QR()` (or equivalently, `QRpos()`), `SVD()` and -`SDD()`. The first assumes that the matrix is full rank and requires `iszero(atol)` and -`iszero(rtol)`. With `SVD()` and `SDD()`, `rightnull` will use the corresponding singular -value decomposition, and one can specify an absolute or relative tolerance for which -singular values are to be considered zero, where `max(atol, norm(t)*rtol)` is used as upper -bound. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`leftnull(!)` is currently only implemented for -`InnerProductStyle(t) === EuclideanInnerProduct()`. -""" -function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p) - return leftnull!(tcopy; kwargs...) -end - -""" - rightnull(t::AbstractTensor, (leftind, rightind)::Index2Tuple; - alg::OrthogonalFactorizationAlgorithm = LQ(), - atol::Real = 0.0, - rtol::Real = eps(real(float(one(scalartype(t)))))*iszero(atol)) -> N - -Create orthonormal basis for the orthogonal complement of the support of the indices in -`rightind`, such that `permute(t, (leftind, rightind))*N' = 0`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `rightnull!(t, alg = LQpos())`. - -Different algorithms are available, namely `LQ()` (or equivalently, `LQpos`), `SVD()` and -`SDD()`. The first assumes that the matrix is full rank and requires `iszero(atol)` and -`iszero(rtol)`. With `SVD()` and `SDD()`, `rightnull` will use the corresponding singular -value decomposition, and one can specify an absolute or relative tolerance for which -singular values are to be considered zero, where `max(atol, norm(t)*rtol)` is used as upper -bound. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`rightnull(!)` is currently only implemented for -`InnerProductStyle(t) === EuclideanInnerProduct()`. -""" -function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p) - return rightnull!(tcopy; kwargs...) -end - -""" - eigen(t::AbstractTensor, (leftind, rightind)::Index2Tuple; kwargs...) -> D, V - -Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `eigen!(t)`. Note that the permuted tensor on which -`eigen!` is called should have equal domain and codomain, as otherwise the eigenvalue -decomposition is meaningless and cannot satisfy -``` -permute(t, (leftind, rightind)) * V = V * D -``` - -Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense -matrices. See the corresponding documentation for more information. - -See also `eig` and `eigh` -""" -function LinearAlgebra.eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p) - return eigen!(tcopy; kwargs...) -end - -function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...) - tcopy = copy_oftype(t, factorisation_scalartype(eigen, t)) - return LinearAlgebra.eigvals!(tcopy; kwargs...) -end - -""" - eig(t::AbstractTensor, (leftind, rightind)::Index2Tuple; kwargs...) -> D, V - -Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. -The function `eig` assumes that the linear map is not hermitian and returns type stable -complex valued `D` and `V` tensors for both real and complex valued `t`. See `eigh` for -hermitian linear maps - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `eig!(t)`. Note that the permuted tensor on -which `eig!` is called should have equal domain and codomain, as otherwise the eigenvalue -decomposition is meaningless and cannot satisfy -``` -permute(t, (leftind, rightind)) * V = V * D -``` - -Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense -matrices. See the corresponding documentation for more information. - -See also `eigen` and `eigh`. -""" -function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(eig, t), p) - return eig!(tcopy; kwargs...) -end - -""" - eigh(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple) -> D, V - -Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. -The function `eigh` assumes that the linear map is hermitian and `D` and `V` tensors with -the same `scalartype` as `t`. See `eig` and `eigen` for non-hermitian tensors. Hermiticity -requires that the tensor acts on inner product spaces, and the current implementation -requires `InnerProductStyle(t) === EuclideanInnerProduct()`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `eigh!(t)`. Note that the permuted tensor on -which `eigh!` is called should have equal domain and codomain, as otherwise the eigenvalue -decomposition is meaningless and cannot satisfy -``` -permute(t, (leftind, rightind)) * V = V * D -``` - -See also `eigen` and `eig`. -""" -function eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigh, t), p) - return eigh!(tcopy; kwargs...) -end - -""" - isposdef(t::AbstractTensor, (leftind, rightind)::Index2Tuple) -> ::Bool - -Test whether a tensor `t` is positive definite as linear map from `rightind` to `leftind`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `isposdef!(t)`. Note that the permuted tensor on -which `isposdef!` is called should have equal domain and codomain, as otherwise it is -meaningless. -""" -function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(isposdef, t), p) - return isposdef!(tcopy) -end - -function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) - t = permute(t, (p₁, p₂); copy=false) - return isisometry(t) -end - -function tsvd(t::AbstractTensorMap; kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return tsvd!(tcopy; kwargs...) -end -function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return leftorth!(tcopy; alg=alg, kwargs...) -end -function rightorth(t::AbstractTensorMap; alg::OFA=LQpos(), kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return rightorth!(tcopy; alg=alg, kwargs...) -end -function leftnull(t::AbstractTensorMap; alg::OFA=QR(), kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return leftnull!(tcopy; alg=alg, kwargs...) -end -function rightnull(t::AbstractTensorMap; alg::OFA=LQ(), kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return rightnull!(tcopy; alg=alg, kwargs...) -end -function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return eigen!(tcopy; kwargs...) -end -function eig(t::AbstractTensorMap; kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return eig!(tcopy; kwargs...) -end -function eigh(t::AbstractTensorMap; kwargs...) - tcopy = copy_oftype(t, float(scalartype(t))) - return eigh!(tcopy; kwargs...) -end -function LinearAlgebra.isposdef(t::AbstractTensorMap) - tcopy = copy_oftype(t, float(scalartype(t))) - return isposdef!(tcopy) -end - -# Orthogonal factorizations (mutation for recycling memory): -# only possible if scalar type is floating point -# only correct if Euclidean inner product -#------------------------------------------------------------------------------------------ -const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} - -function _reverse!(t::AbstractTensorMap; dims=:) - for (c, b) in blocks(t) - reverse!(b; dims) - end - return t -end - -function leftorth!(t::TensorMap{<:RealOrComplexFloat}; - alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar,Nothing}=nothing, - kwargs...) - # atol::Real=zero(float(real(scalartype(t)))), - # rtol::Real=(alg ∉ (SVD(), SDD())) ? - # zero(float(real(scalartype(t)))) : - # eps(real(float(one(scalartype(t))))) * - # iszero(atol)) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftorth!) - return _leftorth!(t, alg; kwargs...) - - # if alg == SVD() || alg == SDD() - # return _leftorth!(t, alg; atol, rtol) - # else - # (iszero(atol) && iszero(rtol)) || - # throw(ArgumentError("`leftorth!` with nonzero atol or rtol requires SVD or SDD algorithm")) - # return _leftorth!(t, alg) - # end -end - -# this promotes the algorithm to a positional argument for type stability reasons -# since polar has different number of output legs -# TODO: this seems like duplication from MatrixAlgebraKit.left_orth!, but that function -# only has its logic with the output already specified, which breaks for polar -function _leftorth!(t::TensorMap{<:RealOrComplexFloat}, alg; kwargs...) - trunc = isempty(kwargs) ? nothing : (; kwargs...) - if isnothing(alg) - return left_orth!(t; trunc) - elseif alg == SVD() - return left_orth!(t; kind=:svd, alg_svd=:LAPACK_QRIteration, trunc) - elseif alg == SDD() - return left_orth!(t; kind=:svd, alg_svd=:LAPACK_DivideAndConquer, trunc) - elseif alg == QR() - return left_orth!(t; kind=:qr, alg_qr=(; positive=false), trunc) - elseif alg == QRpos() - return left_orth!(t; kind=:qr, alg_qr=(; positive=true), trunc) - elseif alg == QL() || alg == QLpos() - _reverse!(t; dims=2) - Q, R = left_orth!(t; kind=:qr, alg_qr=(; positive=alg == QLpos()), trunc) - _reverse!(Q; dims=2) - _reverse!(R) - return Q, R - elseif alg == Polar() - return left_orth!(t; kind=:polar, trunc) - else - throw(ArgumentError(lazy"Invalid algorithm: $alg")) - end -end - - -function leftnull!(t::TensorMap{<:RealOrComplexFloat}; - alg::Union{QR,QRpos,SVD,SDD}=QRpos(), - atol::Real=zero(float(real(scalartype(t)))), - rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) : - eps(real(float(one(scalartype(t))))) * iszero(atol)) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftnull!) - - if alg == SVD() || alg == SDD() - kind = :svd - alg_svd = BlockAlgorithm(alg == SVD() ? MatrixAlgebraKit.LAPACK_QRIteration() : - MatrixAlgebraKit.LAPACK_DivideAndConquer(), - default_blockscheduler(t)) - trunc = if iszero(atol) && iszero(rtol) - nothing - else - (; atol, rtol) - end - return left_null!(t; kind, alg_svd, trunc) - end - - (iszero(atol) && iszero(rtol)) || - throw(ArgumentError("`leftnull!` with nonzero atol or rtol requires SVD or SDD algorithm")) - - kind = :qr - alg_qr = (; positive=alg == QRpos()) - return left_null!(t; kind, alg_qr) -end - -function rightorth!(t::TensorMap{<:RealOrComplexFloat}; - alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar}=LQpos(), - atol::Real=zero(float(real(scalartype(t)))), - rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) : - eps(real(float(one(scalartype(t))))) * iszero(atol)) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightorth!) - if !iszero(rtol) - atol = max(atol, rtol * norm(t)) - end - I = sectortype(t) - dims = SectorDict{I,Int}() - - # compute LQ factorization for each block - if !isempty(blocks(t)) - generator = Base.Iterators.map(blocks(t)) do (c, b) - Lc, Qc = MatrixAlgebra.rightorth!(b, alg, atol) - dims[c] = size(Qc, 1) - return c => (Lc, Qc) - end - LQdata = SectorDict(generator) - end - - # construct new space - S = spacetype(t) - V = S(dims) - if alg isa Polar - @assert V ≅ codomain(t) - W = codomain(t) - elseif length(codomain(t)) == 1 && codomain(t) ≅ V - W = codomain(t) - elseif length(domain(t)) == 1 && domain(t) ≅ V - W = domain(t) - else - W = ProductSpace(V) - end - - # construct output tensors - T = float(scalartype(t)) - L = similar(t, T, codomain(t) ← W) - Q = similar(t, T, W ← domain(t)) - if !isempty(blocks(t)) - for (c, (Lc, Qc)) in LQdata - copy!(block(L, c), Lc) - copy!(block(Q, c), Qc) - end - end - return L, Q -end - -function rightnull!(t::TensorMap{<:RealOrComplexFloat}; - alg::Union{LQ,LQpos,SVD,SDD}=LQpos(), - atol::Real=zero(float(real(scalartype(t)))), - rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) : - eps(real(float(one(scalartype(t))))) * iszero(atol)) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightnull!) - if !iszero(rtol) - atol = max(atol, rtol * norm(t)) - end - I = sectortype(t) - dims = SectorDict{I,Int}() - - # compute LQ factorization for each block - V = domain(t) - if !isempty(blocksectors(V)) - generator = Base.Iterators.map(blocksectors(V)) do c - Nc = MatrixAlgebra.rightnull!(block(t, c), alg, atol) - dims[c] = size(Nc, 1) - return c => Nc - end - Ndata = SectorDict(generator) - end - - # construct new space - S = spacetype(t) - W = S(dims) - - # construct output tensor - T = float(scalartype(t)) - N = similar(t, T, W ← V) - if !isempty(blocksectors(V)) - for (c, Nc) in Ndata - copy!(block(N, c), Nc) - end - end - return N -end - -function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftorth!) - return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg'))) -end - -function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightorth!) - return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg'))) -end - -function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftnull!) - return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...)) -end - -function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightnull!) - return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...)) -end - -#------------------------------# -# Singular value decomposition # -#------------------------------# -function LinearAlgebra.svdvals!(t::TensorMap{<:RealOrComplexFloat}) - return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t)) -end -LinearAlgebra.svdvals!(t::AdjointTensorMap) = svdvals!(adjoint(t)) - -function tsvd!(t::TensorMap{<:RealOrComplexFloat}; - trunc=NoTruncation(), p::Real=2, alg=SDD()) - return _tsvd!(t, alg, trunc, p) -end -function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) - u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) - return adjoint(vt), adjoint(s), adjoint(u), err -end - -# implementation dispatches on algorithm -function _tsvd!(t::TensorMap{<:BlasFloat}, alg::Union{SVD,SDD}, - ::NoTruncation, p::Real=2) - scheduler = default_blockscheduler(t) - svd_alg = alg isa SDD ? LAPACK_DivideAndConquer() : LAPACK_QRIteration() - return MatrixAlgebraKit.svd_compact!(t; alg=BlockAlgorithm(svd_alg, scheduler))..., - zero(real(scalartype(t))) -end -function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD}, - trunc::TruncationScheme, p::Real=2) - # early return - if isempty(blocksectors(t)) - truncerr = zero(real(scalartype(t))) - return _empty_svdtensors(t)..., truncerr - end - - # compute SVD factorization for each block - S = spacetype(t) - SVDdata, dims = _compute_svddata!(t, alg) - Σdata = SectorDict(c => Σ for (c, (U, Σ, V)) in SVDdata) - truncdim = _compute_truncdim(Σdata, trunc, p) - truncerr = _compute_truncerr(Σdata, truncdim, p) - - # construct output tensors - U, Σ, V⁺ = _create_svdtensors(t, SVDdata, truncdim) - return U, Σ, V⁺, truncerr -end - -# helper functions -function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}) - InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) - I = sectortype(t) - dims = SectorDict{I,Int}() - generator = Base.Iterators.map(blocks(t)) do (c, b) - U, Σ, V = MatrixAlgebra.svd!(b, alg) - dims[c] = length(Σ) - return c => (U, Σ, V) - end - SVDdata = SectorDict(generator) - return SVDdata, dims -end - -function _create_svdtensors(t::TensorMap{<:RealOrComplexFloat}, SVDdata, dims) - T = scalartype(t) - S = spacetype(t) - W = S(dims) - - Tr = real(T) - A = similarstoragetype(t, Tr) - Σ = DiagonalTensorMap{Tr,S,A}(undef, W) - - U = similar(t, codomain(t) ← W) - V⁺ = similar(t, W ← domain(t)) - for (c, (Uc, Σc, V⁺c)) in SVDdata - r = Base.OneTo(dims[c]) - copy!(block(U, c), view(Uc, :, r)) - copy!(block(Σ, c), Diagonal(view(Σc, r))) - copy!(block(V⁺, c), view(V⁺c, r, :)) - end - return U, Σ, V⁺ -end - -function _empty_svdtensors(t::TensorMap{<:RealOrComplexFloat}) - T = scalartype(t) - S = spacetype(t) - I = sectortype(t) - dims = SectorDict{I,Int}() - W = S(dims) - - Tr = real(T) - A = similarstoragetype(t, Tr) - Σ = DiagonalTensorMap{Tr,S,A}(undef, W) - - U = similar(t, codomain(t) ← W) - V⁺ = similar(t, W ← domain(t)) - return U, Σ, V⁺ -end - -#--------------------------# -# Eigenvalue decomposition # -#--------------------------# -function LinearAlgebra.eigen!(t::TensorMap{<:RealOrComplexFloat}) - return ishermitian(t) ? eigh!(t) : eig!(t) -end - -function LinearAlgebra.eigvals!(t::TensorMap{<:RealOrComplexFloat}; kwargs...) - return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...)) - for (c, b) in blocks(t)) -end -function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwargs...) - return SectorDict(c => conj!(complex(LinearAlgebra.eigvals!(b; kwargs...))) - for (c, b) in blocks(t)) -end - -eigh!(t::TensorMap{<:RealOrComplexFloat}) = eigh_full!(t) -eig!(t::TensorMap{<:RealOrComplexFloat}) = eig_full!(t) - -# function eigh!(t::TensorMap{<:RealOrComplexFloat}) -# InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!) -# domain(t) == codomain(t) || -# throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same")) - -# T = scalartype(t) -# I = sectortype(t) -# S = spacetype(t) -# dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t)) -# W = S(dims) - -# Tr = real(T) -# A = similarstoragetype(t, Tr) -# D = DiagonalTensorMap{Tr,S,A}(undef, W) -# V = similar(t, domain(t) ← W) -# for (c, b) in blocks(t) -# values, vectors = MatrixAlgebra.eigh!(b) -# copy!(block(D, c), Diagonal(values)) -# copy!(block(V, c), vectors) -# end -# return D, V -# end - -# function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...) -# domain(t) == codomain(t) || -# throw(SpaceMismatch("`eig!` requires domain and codomain to be the same")) - -# T = scalartype(t) -# I = sectortype(t) -# S = spacetype(t) -# dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t)) -# W = S(dims) - -# Tc = complex(T) -# A = similarstoragetype(t, Tc) -# D = DiagonalTensorMap{Tc,S,A}(undef, W) -# V = similar(t, Tc, domain(t) ← W) -# for (c, b) in blocks(t) -# values, vectors = MatrixAlgebra.eig!(b; kwargs...) -# copy!(block(D, c), Diagonal(values)) -# copy!(block(V, c), vectors) -# end -# return D, V -# end - -#--------------------------------------------------# -# Checks for hermiticity and positive definiteness # -#--------------------------------------------------# -function LinearAlgebra.ishermitian(t::TensorMap) - domain(t) == codomain(t) || return false - InnerProductStyle(t) === EuclideanInnerProduct() || return false # hermiticity only defined for euclidean - for (c, b) in blocks(t) - ishermitian(b) || return false - end - return true -end - -function LinearAlgebra.isposdef!(t::TensorMap) - domain(t) == codomain(t) || - throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) - InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false - for (c, b) in blocks(t) - isposdef!(b) || return false - end - return true -end - -# TODO: tolerances are per-block, not global or weighted - does that matter? -function isisometry(t::AbstractTensorMap; kwargs...) - domain(t) ≾ codomain(t) || return false - for (_, b) in blocks(t) - MatrixAlgebra.isisometry(b; kwargs...) || return false - end - return true -end diff --git a/src/tensors/factorizations/deprecations.jl b/src/tensors/factorizations/deprecations.jl new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/tensors/factorizations/deprecations.jl @@ -0,0 +1 @@ + diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl new file mode 100644 index 000000000..a83b77a35 --- /dev/null +++ b/src/tensors/factorizations/factorizations.jl @@ -0,0 +1,215 @@ +# Tensor factorization +#---------------------- +# using submodule here to import MatrixAlgebraKit functions without polluting namespace +module Factorizations + +export eig, eig!, eigh, eigh! +export tsvd, tsvd!, svdvals +export leftorth, leftorth!, rightorth, rightorth! +export leftnull, leftnull!, rightnull, rightnull! +export leftpolar, leftpolar!, rightpolar, rightpolar! +export copy_oftype, permutedcopy_oftype +export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace + +using ..TensorKit +using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock +using ..MatrixAlgebra: MatrixAlgebra + +using LinearAlgebra: LinearAlgebra, BlasFloat +import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian + +using TensorOperations: Index2Tuple + +using MatrixAlgebraKit +using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy, + NoTruncation, TruncationKeepAbove, TruncationKeepBelow, + TruncationIntersection, TruncationKeepFiltered +import MatrixAlgebraKit: select_algorithm, + default_qr_algorithm, default_lq_algorithm, + default_eig_algorithm, default_eigh_algorithm, + default_svd_algorithm, default_polar_algorithm, + copy_input, check_input, initialize_output, + qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, + svd_compact!, svd_full!, svd_trunc!, + eig_full!, eig_trunc!, eigh_full!, eigh_trunc!, + left_polar!, left_orth_polar!, right_polar!, right_orth_polar!, + left_null_svd!, right_null_svd!, + left_orth!, right_orth!, left_null!, right_null!, + truncate!, findtruncated, findtruncated_sorted, + diagview + +include("utility.jl") +include("interface.jl") +include("implementations.jl") +include("matrixalgebrakit.jl") +include("truncation.jl") +include("deprecations.jl") + +""" + isposdef(t::AbstractTensor, (leftind, rightind)::Index2Tuple) -> ::Bool + +Test whether a tensor `t` is positive definite as linear map from `rightind` to `leftind`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in +`t` to be destroyed/overwritten, by using `isposdef!(t)`. Note that the permuted tensor on +which `isposdef!` is called should have equal domain and codomain, as otherwise it is +meaningless. +""" +function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) + tcopy = permutedcopy_oftype(t, factorisation_scalartype(isposdef, t), p) + return isposdef!(tcopy) +end +function LinearAlgebra.isposdef(t::AbstractTensorMap) + tcopy = copy_oftype(t, float(scalartype(t))) + return isposdef!(tcopy) +end + +function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) + t = permute(t, (p₁, p₂); copy=false) + return isisometry(t) +end + +# Orthogonal factorizations (mutation for recycling memory): +# only possible if scalar type is floating point +# only correct if Euclidean inner product +#------------------------------------------------------------------------------------------ +const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} + +# AdjointTensorMap +# ---------------- +function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos()) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:leftorth!) + return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg'))) +end + +function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos()) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightorth!) + return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg'))) +end + +function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:leftnull!) + return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...)) +end + +function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightnull!) + return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...)) +end + +function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) + u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) + return adjoint(vt), adjoint(s), adjoint(u), err +end + +# DiagonalTensorMap +# ----------------- +function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...) + @assert alg isa Union{QR,QL} + return one(d), d # TODO: this is only correct for `alg = QR()` or `alg = QL()` +end +function rightorth!(d::DiagonalTensorMap; alg=LQ(), kwargs...) + @assert alg isa Union{LQ,RQ} + return d, one(d) # TODO: this is only correct for `alg = LQ()` or `alg = RQ()` +end +leftnull!(d::DiagonalTensorMap; kwargs...) = leftnull!(TensorMap(d); kwargs...) +rightnull!(d::DiagonalTensorMap; kwargs...) = rightnull!(TensorMap(d); kwargs...) + +function tsvd!(d::DiagonalTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) + return _tsvd!(d, alg, trunc, p) +end + +# helper function +function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD}) + InnerProductStyle(d) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) + I = sectortype(d) + dims = SectorDict{I,Int}() + generator = Base.Iterators.map(blocks(d)) do (c, b) + lb = length(b.diag) + U = zerovector!(similar(b.diag, lb, lb)) + V = zerovector!(similar(b.diag, lb, lb)) + p = sortperm(b.diag; by=abs, rev=true) + for (i, pi) in enumerate(p) + U[pi, i] = MatrixAlgebra.safesign(b.diag[pi]) + V[i, pi] = 1 + end + Σ = abs.(view(b.diag, p)) + dims[c] = lb + return c => (U, Σ, V) + end + SVDdata = SectorDict(generator) + return SVDdata, dims +end + +eig!(d::DiagonalTensorMap) = d, one(d) +eigh!(d::DiagonalTensorMap{<:Real}) = d, one(d) +eigh!(d::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(real(d.data), d.domain), one(d) + +function LinearAlgebra.svdvals(d::DiagonalTensorMap) + return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d)) +end +function LinearAlgebra.eigvals(d::DiagonalTensorMap) + return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d)) +end + +function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2) + return LinearAlgebra.cond(Diagonal(d.data), p) +end +#------------------------------# +# Singular value decomposition # +#------------------------------# +function LinearAlgebra.svdvals!(t::TensorMap{<:RealOrComplexFloat}) + return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t)) +end +LinearAlgebra.svdvals!(t::AdjointTensorMap) = svdvals!(adjoint(t)) + +#--------------------------# +# Eigenvalue decomposition # +#--------------------------# + +function LinearAlgebra.eigvals!(t::TensorMap{<:RealOrComplexFloat}; kwargs...) + return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...)) + for (c, b) in blocks(t)) +end +function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwargs...) + return SectorDict(c => conj!(complex(LinearAlgebra.eigvals!(b; kwargs...))) + for (c, b) in blocks(t)) +end + +#--------------------------------------------------# +# Checks for hermiticity and positive definiteness # +#--------------------------------------------------# +function LinearAlgebra.ishermitian(t::TensorMap) + domain(t) == codomain(t) || return false + InnerProductStyle(t) === EuclideanInnerProduct() || return false # hermiticity only defined for euclidean + for (c, b) in blocks(t) + ishermitian(b) || return false + end + return true +end + +function LinearAlgebra.isposdef!(t::TensorMap) + domain(t) == codomain(t) || + throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) + InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false + for (c, b) in blocks(t) + isposdef!(b) || return false + end + return true +end + +# TODO: tolerances are per-block, not global or weighted - does that matter? +function isisometry(t::AbstractTensorMap; kwargs...) + domain(t) ≾ codomain(t) || return false + for (_, b) in blocks(t) + MatrixAlgebra.isisometry(b; kwargs...) || return false + end + return true +end + +end diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl new file mode 100644 index 000000000..89c0d0a00 --- /dev/null +++ b/src/tensors/factorizations/implementations.jl @@ -0,0 +1,167 @@ +_kindof(::Union{SVD,SDD}) = :svd +_kindof(::Union{QR,QRpos}) = :qr +_kindof(::Union{LQ,LQpos}) = :lq +_kindof(::Polar) = :polar + +for f! in (:svd_compact!, :svd_full!, :left_null_svd!, :right_null_svd!) + @eval function select_algorithm(::typeof($f!), t::T, alg::SVD; + kwargs...) where {T} + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed")) + return LAPACK_QRIteration() + end + @eval function select_algorithm(::typeof($f!), t::AbstractTensorMap, alg::SVD; + kwargs...) + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed")) + return LAPACK_QRIteration() + end + @eval function select_algorithm(::typeof($f!), ::Type{T}, alg::SVD; + kwargs...) where {T<:AbstractTensorMap} + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed")) + return LAPACK_QRIteration() + end + @eval function select_algorithm(::typeof($f!), t::T, alg::SDD; + kwargs...) where {T} + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed")) + return LAPACK_DivideAndConquer() + end + @eval function select_algorithm(::typeof($f!), t::AbstractTensorMap, alg::SDD; + kwargs...) + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed")) + return LAPACK_DivideAndConquer() + end + @eval function select_algorithm(::typeof($f!), ::Type{T}, alg::SDD; + kwargs...) where {T<:AbstractTensorMap} + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed")) + return LAPACK_DivideAndConquer() + end +end + +leftorth!(t::AbstractTensorMap; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) + +function _leftorth!(t::AbstractTensorMap, ::Nothing; kwargs...) + return isempty(kwargs) ? left_orth!(t) : left_orth!(t; trunc=(; kwargs...)) +end +function _leftorth!(t::AbstractTensorMap, alg::Union{QL,QLpos}; kwargs...) + trunc = isempty(kwargs) ? nothing : (; kwargs...) + + if alg == QL() || alg == QLpos() + _reverse!(t; dims=2) + Q, R = left_orth!(t; kind=:qr, alg_qr=(; positive=alg == QLpos()), trunc) + _reverse!(Q; dims=2) + _reverse!(R) + return Q, R + end +end +function _leftorth!(t, alg::OFA; kwargs...) + trunc = isempty(kwargs) ? nothing : (; kwargs...) + + kind = _kindof(alg) + if kind == :svd + return left_orth!(t; kind, alg_svd=alg, trunc) + elseif kind == :qr + alg_qr = (; positive=(alg == QRpos())) + return left_orth!(t; kind, alg_qr, trunc) + elseif kind == :polar + return left_orth!(t; kind, trunc) + else + throw(ArgumentError(lazy"Invalid algorithm: $alg")) + end +end +# fallback to MatrixAlgebraKit version +_leftorth!(t, alg; kwargs...) = left_orth!(t; alg, kwargs...) + +function leftnull!(t::AbstractTensorMap; + alg::Union{QR,QRpos,SVD,SDD,Nothing}=nothing, kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:leftnull!) + trunc = isempty(kwargs) ? nothing : (; kwargs...) + + isnothing(alg) && return left_null!(t; trunc) + + kind = _kindof(alg) + if kind == :svd + return left_null!(t; kind, alg_svd=alg, trunc) + elseif kind == :qr + alg_qr = (; positive=(alg == QRpos())) + return left_null!(t; kind, alg_qr, trunc) + else + throw(ArgumentError(lazy"Invalid `leftnull!` algorithm: $alg")) + end +end + +leftpolar!(t::AbstractTensorMap; kwargs...) = left_polar!(t; kwargs...) + +function rightorth!(t::AbstractTensorMap; + alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar,Nothing}=nothing, kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightorth!) + trunc = isempty(kwargs) ? nothing : (; kwargs...) + + isnothing(alg) && return right_orth!(t; trunc) + + if alg == RQ() || alg == RQpos() + _reverse!(t; dims=1) + L, Q = right_orth!(t; kind=:lq, alg_lq=(; positive=alg == RQpos()), trunc) + _reverse!(Q; dims=1) + _reverse!(L) + return L, Q + end + + kind = _kindof(alg) + if kind == :svd + return right_orth!(t; kind, alg_svd=alg, trunc) + elseif kind == :lq + alg_lq = (; positive=(alg == LQpos())) + return right_orth!(t; kind, alg_lq, trunc) + elseif kind == :polar + return right_orth!(t; kind, trunc) + else + throw(ArgumentError(lazy"Invalid `rightorth!` algorithm: $alg")) + end +end + +function rightnull!(t::AbstractTensorMap; + alg::Union{LQ,LQpos,SVD,SDD,Nothing}=nothing, kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightnull!) + trunc = isempty(kwargs) ? nothing : (; kwargs...) + + isnothing(alg) && return right_null!(t; trunc) + + kind = _kindof(alg) + if kind == :svd + return right_null!(t; kind, alg_svd=alg, trunc) + elseif kind == :lq + alg_lq = (; positive=(alg == LQpos())) + return right_null!(t; kind, alg_lq, trunc) + else + throw(ArgumentError(lazy"Invalid `rightnull!` algorithm: $alg")) + end +end + +rightpolar!(t::AbstractTensorMap; kwargs...) = right_polar!(t; kwargs...) + +# Eigenvalue decomposition +# ------------------------ +eigh!(t::AbstractTensorMap) = eigh_full!(t) +eig!(t::AbstractTensorMap) = eig_full!(t) +eigen!(t::AbstractTensorMap) = ishermitian(t) ? eigh!(t) : eig!(t) + +# Singular value decomposition +# ---------------------------- +function tsvd!(t::AbstractTensorMap; trunc=notrunc(), p=nothing, kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) + isnothing(p) || Base.depwarn("p is no longer supported", :tsvd!) + + if trunc == notrunc() + return svd_compact!(t; kwargs...) + else + return svd_trunc!(t; trunc, kwargs...) + end +end diff --git a/src/tensors/factorizations/interface.jl b/src/tensors/factorizations/interface.jl new file mode 100644 index 000000000..821bc15c3 --- /dev/null +++ b/src/tensors/factorizations/interface.jl @@ -0,0 +1,242 @@ +@doc """ + tsvd(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; + trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD()) + -> U, S, V, ϵ + tsvd!(t::AbstractTensorMap, trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD()) + -> U, S, V, ϵ + +Compute the (possibly truncated) singular value decomposition such that +`norm(permute(t, (leftind, rightind)) - U * S * V) ≈ ϵ`, where `ϵ` thus represents the truncation error. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in +`t` to be destroyed/overwritten, by using `tsvd!(t, trunc = notrunc(), p = 2)`. + +A truncation parameter `trunc` can be specified for the new internal dimension, in which +case a truncated singular value decomposition will be computed. Choices are: +* `notrunc()`: no truncation (default); +* `truncerr(η::Real)`: truncates such that the p-norm of the truncated singular values is + smaller than `η`; +* `truncdim(χ::Int)`: truncates such that the equivalent total dimension of the internal + vector space is no larger than `χ`; +* `truncspace(V)`: truncates such that the dimension of the internal vector space is + smaller than that of `V` in any sector. +* `truncbelow(η::Real)`: truncates such that every singular value is larger then `η` ; + +Truncation options can also be combined using `&`, i.e. `truncbelow(η) & truncdim(χ)` will +choose the truncation space such that every singular value is larger than `η`, and the +equivalent total dimension of the internal vector space is no larger than `χ`. + +The method `tsvd` also returns the truncation error `ϵ`, computed as the `p` norm of the +singular values that were truncated. + +THe keyword `alg` can be equal to `SVD()` or `SDD()`, corresponding to the underlying LAPACK +algorithm that computes the decomposition (`_gesvd` or `_gesdd`). + +Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)` +is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`. +""" tsvd, tsvd! + +@doc """ + eig(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> D, V + eig!(t::AbstractTensorMap; kwargs...) -> D, V + +Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. +The function `eig` assumes that the linear map is not hermitian and returns type stable +complex valued `D` and `V` tensors for both real and complex valued `t`. See `eigh` for +hermitian linear maps + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in +`t` to be destroyed/overwritten, by using `eig!(t)`. Note that the permuted tensor on +which `eig!` is called should have equal domain and codomain, as otherwise the eigenvalue +decomposition is meaningless and cannot satisfy +``` +permute(t, (leftind, rightind)) * V = V * D +``` + +Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense +matrices. See the corresponding documentation for more information. + +See also `eigen` and `eigh`. +""" eig + +@doc """ + eigh(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> D, V + eigh!(t::AbstractTensorMap; kwargs...) -> D, V + +Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. +The function `eigh` assumes that the linear map is hermitian and `D` and `V` tensors with +the same `scalartype` as `t`. See `eig` and `eigen` for non-hermitian tensors. Hermiticity +requires that the tensor acts on inner product spaces, and the current implementation +requires `InnerProductStyle(t) === EuclideanInnerProduct()`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in +`t` to be destroyed/overwritten, by using `eigh!(t)`. Note that the permuted tensor on +which `eigh!` is called should have equal domain and codomain, as otherwise the eigenvalue +decomposition is meaningless and cannot satisfy +``` +permute(t, (leftind, rightind)) * V = V * D +``` + +See also `eigen` and `eig`. +""" eigh, eigh! + +@doc """ + leftorth(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple; + alg::OrthogonalFactorizationAlgorithm = QRpos()) -> Q, R + +Create orthonormal basis `Q` for indices in `leftind`, and remainder `R` such that +`permute(t, (leftind, rightind)) = Q*R`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` +to be destroyed/overwritten, by using `leftorth!(t, alg = QRpos())`. + +Different algorithms are available, namely `QR()`, `QRpos()`, `SVD()` and `Polar()`. `QR()` +and `QRpos()` use a standard QR decomposition, producing an upper triangular matrix `R`. +`Polar()` produces a Hermitian and positive semidefinite `R`. `QRpos()` corrects the +standard QR decomposition such that the diagonal elements of `R` are positive. Only +`QRpos()` and `Polar()` are unique (no residual freedom) so that they always return the same +result for the same input tensor `t`. + +Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and +`leftorth(!)` is currently only implemented for + `InnerProductStyle(t) === EuclideanInnerProduct()`. +""" leftorth, leftorth! + +@doc """ + rightorth(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; + alg::OrthogonalFactorizationAlgorithm = LQpos()) -> L, Q + rightorth!(t::AbstractTensorMap; alg) -> L, Q + +Create orthonormal basis `Q` for indices in `rightind`, and remainder `L` such that +`permute(t, (leftind, rightind)) = L*Q`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` +to be destroyed/overwritten, by using `rightorth!(t, alg = LQpos())`. + +Different algorithms are available, namely `LQ()`, `LQpos()`, `RQ()`, `RQpos()`, `SVD()` and +`Polar()`. `LQ()` and `LQpos()` produce a lower triangular matrix `L` and are computed using +a QR decomposition of the transpose. `RQ()` and `RQpos()` produce an upper triangular +remainder `L` and only works if the total left dimension is smaller than or equal to the +total right dimension. `LQpos()` and `RQpos()` add an additional correction such that the +diagonal elements of `L` are positive. `Polar()` produces a Hermitian and positive +semidefinite `L`. Only `LQpos()`, `RQpos()` and `Polar()` are unique (no residual freedom) +so that they always return the same result for the same input tensor `t`. + +Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and +`rightorth(!)` is currently only implemented for +`InnerProductStyle(t) === EuclideanInnerProduct()`. +""" rightorth, rightorth! + +@doc """ + leftnull(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; + alg::OrthogonalFactorizationAlgorithm = QRpos()) -> N + leftnull!(t::AbstractTensorMap; alg) -> N + +Create orthonormal basis for the orthogonal complement of the support of the indices in +`leftind`, such that `N' * permute(t, (leftind, rightind)) = 0`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` +to be destroyed/overwritten, by using `leftnull!(t, alg = QRpos())`. + +Different algorithms are available, namely `QR()` (or equivalently, `QRpos()`), `SVD()` and +`SDD()`. The first assumes that the matrix is full rank and requires `iszero(atol)` and +`iszero(rtol)`. With `SVD()` and `SDD()`, `rightnull` will use the corresponding singular +value decomposition, and one can specify an absolute or relative tolerance for which +singular values are to be considered zero, where `max(atol, norm(t)*rtol)` is used as upper +bound. + +Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and +`leftnull(!)` is currently only implemented for +`InnerProductStyle(t) === EuclideanInnerProduct()`. +""" leftnull, leftnull! + +@doc """ + rightnull(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; + alg::OrthogonalFactorizationAlgorithm = LQ(), + atol::Real = 0.0, + rtol::Real = eps(real(float(one(scalartype(t)))))*iszero(atol)) -> N + rightnull!(t::AbstractTensorMap; alg, atol, rtol) + +Create orthonormal basis for the orthogonal complement of the support of the indices in +`rightind`, such that `permute(t, (leftind, rightind))*N' = 0`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` +to be destroyed/overwritten, by using `rightnull!(t, alg = LQpos())`. + +Different algorithms are available, namely `LQ()` (or equivalently, `LQpos`), `SVD()` and +`SDD()`. The first assumes that the matrix is full rank and requires `iszero(atol)` and +`iszero(rtol)`. With `SVD()` and `SDD()`, `rightnull` will use the corresponding singular +value decomposition, and one can specify an absolute or relative tolerance for which +singular values are to be considered zero, where `max(atol, norm(t)*rtol)` is used as upper +bound. + +Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and +`rightnull(!)` is currently only implemented for +`InnerProductStyle(t) === EuclideanInnerProduct()`. +""" rightnull, rightnull! + +@doc """ + leftpolar(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> W, P + leftpolar!(t::AbstractTensorMap; kwargs...) -> W, P + +Compute the polar decomposition of tensor `t` as linear map from `rightind` to `leftind`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in +`t` to be destroyed/overwritten, by using `eigh!(t)`. + +See also [`rightpolar(!)`](@ref rightpolar). + +""" leftpolar, leftpolar! + +@doc """ + eigen(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> D, V + eigen!(t::AbstractTensorMap; kwargs...) -> D, V + +Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` +to be destroyed/overwritten, by using `eigen!(t)`. Note that the permuted tensor on which +`eigen!` is called should have equal domain and codomain, as otherwise the eigenvalue +decomposition is meaningless and cannot satisfy +``` +permute(t, (leftind, rightind)) * V = V * D +``` + +Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense +matrices. See the corresponding documentation for more information. + +See also [`eig(!)`](@ref eig) and [`eigh(!)`](@ref) +""" eigen(::AbstractTensorMap), eigen!(::AbstractTensorMap) + +for f in + (:tsvd, :eig, :eigh, :eigen, :leftorth, :rightorth, :leftpolar, :rightpolar, :leftnull, + :rightnull) + f! = Symbol(f, :!) + @eval function $f(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + tcopy = permutedcopy_oftype(t, factorisation_scalartype($f, t), p) + return $f!(tcopy; kwargs...) + end + @eval function $f(t::AbstractTensorMap; kwargs...) + tcopy = copy_oftype(t, factorisation_scalartype($f, t)) + return $f!(tcopy; kwargs...) + end +end + +function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...) + tcopy = copy_oftype(t, factorisation_scalartype(eigen, t)) + return LinearAlgebra.eigvals!(tcopy; kwargs...) +end + +function LinearAlgebra.svdvals(t::AbstractTensorMap) + tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t)) + return LinearAlgebra.svdvals!(tcopy) +end diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl new file mode 100644 index 000000000..605428f84 --- /dev/null +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -0,0 +1,508 @@ +# Algorithm selection +# ------------------- +for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, :svd_full, + :svd_compact, :svd_vals, :svd_trunc) + @eval function copy_input(::typeof($f), t::AbstractTensorMap{<:BlasFloat}) + T = factorisation_scalartype($f, t) + return copy_oftype(t, T) + end + f! = Symbol(f, :!) + # TODO: can we move this to MAK? + @eval function select_algorithm(::typeof($f!), t::AbstractTensorMap, alg::Alg=nothing; + kwargs...) where {Alg} + return select_algorithm($f!, typeof(t), alg; kwargs...) + end + @eval function select_algorithm(::typeof($f!), ::Type{T}, alg::Alg=nothing; + kwargs...) where {T<:AbstractTensorMap,Alg} + return select_algorithm($f!, blocktype(T), alg; kwargs...) + end +end + +for f in (:qr, :lq, :svd, :eig, :eigh, :polar) + default_f_algorithm = Symbol(:default_, f, :_algorithm) + @eval function $default_f_algorithm(::Type{T}; kwargs...) where {T<:AbstractTensorMap} + return $default_f_algorithm(blocktype(T); kwargs...) + end +end + +function _select_truncation(f, ::AbstractTensorMap, + trunc::MatrixAlgebraKit.TruncationStrategy) + return trunc +end +function _select_truncation(::typeof(left_null!), ::AbstractTensorMap, trunc::NamedTuple) + return MatrixAlgebraKit.null_truncation_strategy(; trunc...) +end + +# Generic Implementations +# ----------------------_ +for f! in (:qr_compact!, :qr_full!, + :lq_compact!, :lq_full!, + :eig_full!, :eigh_full!, + :svd_compact!, :svd_full!, + :left_polar!, :left_orth_polar!, :right_polar!, :right_orth_polar!) + @eval function $f!(t::AbstractTensorMap, F, alg::AbstractAlgorithm) + check_input($f!, t, F) + + foreachblock(t, F...) do _, bs + factors = Base.tail(bs) + factors′ = $f!(first(bs), factors, alg) + # deal with the case where the output is not in-place + for (f′, f) in zip(factors′, factors) + f′ === f || copyto!(f, f′) + end + return nothing + end + + return F + end +end + +# Handle these separately because single N instead of tuple +for f! in (:qr_null!, :lq_null!) + @eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) + check_input($f!, t, N) + + foreachblock(t, N) do _, (b, n) + n′ = $f!(b, n, alg) + # deal with the case where the output is not the same as the input + n === n′ || copyto!(n, n′) + return nothing + end + + return N + end +end + +# Singular value decomposition +# ---------------------------- +const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} +const _T_USVᴴ_diag = Tuple{<:AbstractTensorMap,<:DiagonalTensorMap,<:AbstractTensorMap} + +function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ) + # scalartype checks + @check_scalar U t + @check_scalar S t real + @check_scalar Vᴴ t + + # space checks + V_cod = fuse(codomain(t)) + V_dom = fuse(domain(t)) + @check_space(U, codomain(t) ← V_cod) + @check_space(S, V_cod ← V_dom) + @check_space(Vᴴ, V_dom ← domain(t)) + + return nothing +end + +function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ_diag) + # scalartype checks + @check_scalar U t + @check_scalar S t real + @check_scalar Vᴴ t + + # space checks + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) + @check_space(U, codomain(t) ← V_cod) + @check_space(S, V_cod ← V_dom) + @check_space(Vᴴ, V_dom ← domain(t)) + + return nothing +end + +# TODO: svd_vals + +function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_cod = fuse(codomain(t)) + V_dom = fuse(domain(t)) + U = similar(t, codomain(t) ← V_cod) + S = similar(t, real(scalartype(t)), V_cod ← V_dom) + Vᴴ = similar(t, V_dom ← domain(t)) + return U, S, Vᴴ +end + +function initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, + ::AbstractAlgorithm) + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) + U = similar(t, codomain(t) ← V_cod) + S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) + Vᴴ = similar(t, V_dom ← domain(t)) + return U, S, Vᴴ +end + +function initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, + alg::TruncatedAlgorithm) + return initialize_output(svd_compact!, t, alg.alg) +end + +# TODO: svd_vals + +function svd_trunc!(t::AbstractTensorMap, USVᴴ, alg::TruncatedAlgorithm) + USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) + return truncate!(svd_trunc!, USVᴴ′, alg.trunc) +end + +# Eigenvalue decomposition +# ------------------------ +const _T_DV = Tuple{<:DiagonalTensorMap,<:AbstractTensorMap} + +function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV) + domain(t) == codomain(t) || + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + + # scalartype checks + @check_scalar D t real + @check_scalar V t + + # space checks + V_D = fuse(domain(t)) + @check_space(D, V_D ← V_D) + @check_space(V, codomain(t) ← V_D) + + return nothing +end + +function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV) + domain(t) == codomain(t) || + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + + # scalartype checks + @check_scalar D t complex + @check_scalar V t complex + + # space checks + V_D = fuse(domain(t)) + @check_space(D, V_D ← V_D) + @check_space(V, codomain(t) ← V_D) + + return nothing +end + +function initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_D = fuse(domain(t)) + T = real(scalartype(t)) + D = DiagonalTensorMap{T}(undef, V_D) + V = similar(t, codomain(t) ← V_D) + return D, V +end + +function initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_D = fuse(domain(t)) + Tc = complex(scalartype(t)) + D = DiagonalTensorMap{Tc}(undef, V_D) + V = similar(t, Tc, codomain(t) ← V_D) + return D, V +end + +# QR decomposition +# ---------------- +const _T_QR = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} + +function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR) + # scalartype checks + @check_scalar Q t + @check_scalar R t + + # space checks + V_Q = fuse(codomain(t)) + @check_space(Q, codomain(t) ← V_Q) + @check_space(R, V_Q ← domain(t)) + + return nothing +end + +function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR) + # scalartype checks + @check_scalar Q t + @check_scalar R t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + @check_space(Q, codomain(t) ← V_Q) + @check_space(R, V_Q ← domain(t)) + + return nothing +end + +function check_input(::typeof(qr_null!), t::AbstractTensorMap, N::AbstractTensorMap) + # scalartype checks + @check_scalar N t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(codomain(t)), V_Q) + @check_space(N, codomain(t) ← V_N) + + return nothing +end + +function initialize_output(::typeof(qr_full!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_Q = fuse(codomain(t)) + Q = similar(t, codomain(t) ← V_Q) + R = similar(t, V_Q ← domain(t)) + return Q, R +end + +function initialize_output(::typeof(qr_compact!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + Q = similar(t, codomain(t) ← V_Q) + R = similar(t, V_Q ← domain(t)) + return Q, R +end + +function initialize_output(::typeof(qr_null!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(codomain(t)), V_Q) + N = similar(t, codomain(t) ← V_N) + return N +end + +# LQ decomposition +# ---------------- +const _T_LQ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} + +function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ) + # scalartype checks + @check_scalar L t + @check_scalar Q t + + # space checks + V_Q = fuse(domain(t)) + @check_space(L, codomain(t) ← V_Q) + @check_space(Q, V_Q ← domain(t)) + + return nothing +end + +function check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)::_T_LQ) + # scalartype checks + @check_scalar L t + @check_scalar Q t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + @check_space(L, codomain(t) ← V_Q) + @check_space(Q, V_Q ← domain(t)) + + return nothing +end + +function check_input(::typeof(lq_null!), t::AbstractTensorMap, N) + # scalartype checks + @check_scalar N t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(domain(t)), V_Q) + @check_space(N, V_N ← domain(t)) + + return nothing +end + +function initialize_output(::typeof(lq_full!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_Q = fuse(domain(t)) + L = similar(t, codomain(t) ← V_Q) + Q = similar(t, V_Q ← domain(t)) + return L, Q +end + +function initialize_output(::typeof(lq_compact!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + L = similar(t, codomain(t) ← V_Q) + Q = similar(t, V_Q ← domain(t)) + return L, Q +end + +function initialize_output(::typeof(lq_null!), t::AbstractTensorMap, ::AbstractAlgorithm) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(domain(t)), V_Q) + N = similar(t, V_N ← domain(t)) + return N +end + +# Polar decomposition +# ------------------- +const _T_WP = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} +const _T_PWᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} +using MatrixAlgebraKit: PolarViaSVD + +function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP) + codomain(t) ≿ domain(t) || + throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) + + # scalartype checks + @check_scalar W t + @check_scalar P t + + # space checks + @check_space(W, space(t)) + @check_space(P, domain(t) ← domain(t)) + + return nothing +end + +function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, (W, P)::_T_WP) + codomain(t) ≿ domain(t) || + throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) + + # scalartype checks + @check_scalar W t + @check_scalar P t + + # space checks + VW = fuse(domain(t)) + @check_space(W, codomain(t) ← VW) + @check_space(P, VW ← domain(t)) + + return nothing +end + +function initialize_output(::typeof(left_polar!), t::AbstractTensorMap, ::AbstractAlgorithm) + W = similar(t, space(t)) + P = similar(t, domain(t) ← domain(t)) + return W, P +end + +function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ) + domain(t) ≿ codomain(t) || + throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) + + # scalartype checks + @check_scalar P t + @check_scalar Wᴴ t + + # space checks + @check_space(P, codomain(t) ← codomain(t)) + @check_space(Wᴴ, space(t)) + + return nothing +end + +function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ) + domain(t) ≿ codomain(t) || + throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) + + # scalartype checks + @check_scalar P t + @check_scalar Wᴴ t + + # space checks + VW = fuse(codomain(t)) + @check_space(P, codomain(t) ← VW) + @check_space(Wᴴ, VW ← domain(t)) + + return nothing +end + +function initialize_output(::typeof(right_polar!), t::AbstractTensorMap, + ::AbstractAlgorithm) + P = similar(t, codomain(t) ← codomain(t)) + Wᴴ = similar(t, space(t)) + return Wᴴ, P +end + +# Needed to get algorithm selection to behave +function left_orth_polar!(t::AbstractTensorMap, VC, alg) + alg′ = select_algorithm(left_polar!, t, alg) + return left_orth_polar!(t, VC, alg′) +end +function right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg) + alg′ = select_algorithm(right_polar!, t, alg) + return right_orth_polar!(t, CVᴴ, alg′) +end + +# Orthogonalization +# ----------------- +const _T_VC = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} +const _T_CVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} + +function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC) + # scalartype checks + @check_scalar V t + isnothing(C) || @check_scalar C t + + # space checks + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + @check_space(V, codomain(t) ← V_C) + isnothing(C) || @check_space(C, V_C ← domain(t)) + + return nothing +end + +function check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)::_T_CVᴴ) + # scalartype checks + isnothing(C) || @check_scalar C t + @check_scalar Vᴴ t + + # space checks + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + isnothing(C) || @check_space(C, codomain(t) ← V_C) + @check_space(Vᴴ, V_C ← domain(t)) + + return nothing +end + +function initialize_output(::typeof(left_orth!), t::AbstractTensorMap) + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + V = similar(t, codomain(t) ← V_C) + C = similar(t, V_C ← domain(t)) + return V, C +end + +function initialize_output(::typeof(right_orth!), t::AbstractTensorMap) + V_C = infimum(fuse(codomain(t)), fuse(domain(t))) + C = similar(t, codomain(t) ← V_C) + Vᴴ = similar(t, V_C ← domain(t)) + return C, Vᴴ +end + +# Nullspace +# --------- +function check_input(::typeof(left_null!), t::AbstractTensorMap, N) + # scalartype checks + @check_scalar N t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(codomain(t)), V_Q) + @check_space(N, codomain(t) ← V_N) + + return nothing +end + +function check_input(::typeof(right_null!), t::AbstractTensorMap, N) + @check_scalar N t + + # space checks + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(domain(t)), V_Q) + @check_space(N, V_N ← domain(t)) + + return nothing +end + +function initialize_output(::typeof(left_null!), t::AbstractTensorMap) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(codomain(t)), V_Q) + N = similar(t, codomain(t) ← V_N) + return N +end + +function initialize_output(::typeof(right_null!), t::AbstractTensorMap) + V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) + V_N = ⊖(fuse(domain(t)), V_Q) + N = similar(t, V_N ← domain(t)) + return N +end + +for f! in (:left_null_svd!, :right_null_svd!) + @eval function $f!(t::AbstractTensorMap, N, alg, ::Nothing=nothing) + foreachblock(t, N) do _, (b, n) + n′ = $f!(b, n, alg) + # deal with the case where the output is not the same as the input + n === n′ || copyto!(n, n′) + return nothing + end + + return N + end +end diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl new file mode 100644 index 000000000..e9806645a --- /dev/null +++ b/src/tensors/factorizations/truncation.jl @@ -0,0 +1,270 @@ +# truncation.jl +# +# Implements truncation schemes for truncating a tensor with svd, leftorth or rightorth + +notrunc() = NoTruncation() +# deprecate +const TruncationScheme = TruncationStrategy + +struct TruncationError{T<:Real} <: TruncationStrategy + ϵ::T + p::Real +end +truncerr(epsilon::Real, p::Real=2) = TruncationError(epsilon, p) + +# struct TruncationDimension <: TruncationScheme +# dim::Int +# end +@deprecate truncdim(d::Int) truncrank(d) + +struct TruncationSpace{S<:ElementarySpace} <: TruncationStrategy + space::S +end +truncspace(space::ElementarySpace) = TruncationSpace(space) + +struct TruncationCutoff{T<:Real} <: TruncationStrategy + ϵ::T + add_back::Int +end +@deprecate truncbelow(ϵ::Real, add_back::Int=0) begin + add_back == 0 || @warn "add_back is ignored" + trunctol(ϵ) +end +# truncbelow(epsilon::Real, add_back::Int=0) = TruncationCutoff(epsilon, add_back) + +# Compute the total truncation error given truncation dimensions +function _compute_truncerr(Σdata, truncdim, p=2) + I = keytype(Σdata) + S = scalartype(valtype(Σdata)) + return TensorKit._norm((c => @view(v[(get(truncdim, c, 0) + 1):end]) + for (c, v) in Σdata), + p, zero(S)) +end + +# Compute truncation dimensions +# function _compute_truncdim(Σdata, ::NoTruncation, p=2) +# I = keytype(Σdata) +# truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) +# return truncdim +# end +# function _compute_truncdim(Σdata, trunc::TruncationDimension, p=2) +# I = keytype(Σdata) +# truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) +# while sum(dim(c) * d for (c, d) in truncdim) > trunc.dim +# cmin = _findnexttruncvalue(Σdata, truncdim, p) +# isnothing(cmin) && break +# truncdim[cmin] -= 1 +# end +# return truncdim +# end +function _compute_truncdim(Σdata, trunc::TruncationSpace, p=2) + I = keytype(Σdata) + truncdim = SectorDict{I,Int}(c => min(length(v), dim(trunc.space, c)) + for (c, v) in Σdata) + return truncdim +end + +# function _compute_truncdim(Σdata, trunc::TruncationCutoff, p=2) +# I = keytype(Σdata) +# truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) +# for (c, v) in Σdata +# newdim = findlast(Base.Fix2(>, trunc.ϵ), v) +# if newdim === nothing +# truncdim[c] = 0 +# else +# truncdim[c] = newdim +# end +# end +# for i in 1:(trunc.add_back) +# cmax = _findnextgrowvalue(Σdata, truncdim, p) +# isnothing(cmax) && break +# truncdim[cmax] += 1 +# end +# return truncdim +# end + +# Combine truncations +# struct MultipleTruncation{T<:Tuple{Vararg{TruncationScheme}}} <: TruncationScheme +# truncations::T +# end +# function Base.:&(a::MultipleTruncation, b::MultipleTruncation) +# return MultipleTruncation((a.truncations..., b.truncations...)) +# end +# function Base.:&(a::MultipleTruncation, b::TruncationScheme) +# return MultipleTruncation((a.truncations..., b)) +# end +# function Base.:&(a::TruncationScheme, b::MultipleTruncation) +# return MultipleTruncation((a, b.truncations...)) +# end +# Base.:&(a::TruncationScheme, b::TruncationScheme) = MultipleTruncation((a, b)) + +# function _compute_truncdim(Σdata, trunc::MultipleTruncation, p::Real=2) +# truncdim = _compute_truncdim(Σdata, trunc.truncations[1], p) +# for k in 2:length(trunc.truncations) +# truncdimₖ = _compute_truncdim(Σdata, trunc.truncations[k], p) +# for (c, d) in truncdim +# truncdim[c] = min(d, truncdimₖ[c]) +# end +# end +# return truncdim +# end + +# auxiliary function +function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector} + # early return + (isempty(S) || all(iszero, values(truncdim))) && return nothing + σmin, imin = findmin(keys(truncdim)) do c + d = truncdim[c] + return S[c][d] + end + return σmin, keys(truncdim)[imin] +end + +function _findnextgrowvalue(Σdata, truncdim::SectorDict{I,Int}, p::Real) where {I<:Sector} + istruncated = SectorDict{I,Bool}(c => (d < length(Σdata[c])) for (c, d) in truncdim) + # early return + (isempty(Σdata) || !any(values(istruncated))) && return nothing + + # find some suitable starting candidate + cmax = findfirst(istruncated) + σmax = Σdata[cmax][truncdim[cmax] + 1] + + # find the actual maximal singular value + for (c, σs) in Σdata + if istruncated[c] + σ = σs[truncdim[c] + 1] + if σ > σmax + cmax, σmax = c, σ + end + end + end + return cmax +end + +# Truncation +# ---------- + +function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ)::_T_USVᴴ, strategy::TruncationStrategy) + ind = findtruncated_sorted(diagview(S), strategy) + V_truncated = spacetype(S)(c => length(I) for (c, I) in ind) + + Ũ = similar(U, codomain(U) ← V_truncated) + for (c, b) in blocks(Ũ) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b, @view(block(U, c)[:, I])) + end + + S̃ = DiagonalTensorMap{scalartype(S)}(undef, V_truncated) + for (c, b) in blocks(S̃) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b.diag, @view(block(S, c).diag[I])) + end + + Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) + for (c, b) in blocks(Ṽᴴ) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b, @view(block(Vᴴ, c)[I, :])) + end + + return Ũ, S̃, Ṽᴴ +end + +function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove) + atol = if strategy.rtol > 0 + max(strategy.atol, _norm(S, strategy.p) * strategy.rtol) + else + strategy.atol + end + findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol)) + return SectorDict(c => findtrunc(d) for (c, d) in Sd) +end + +function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepBelow) + atol = if strategy.rtol > 0 + max(strategy.atol, _norm(S, strategy.p) * strategy.rtol) + else + strategy.atol + end + findtrunc = Base.Fix2(findtruncated_sorted, truncabove(atol)) + return SectorDict(c => findtrunc(d) for (c, d) in Sd) +end + +function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError) + I = keytype(Sd) + S = real(scalartype(valtype(Sd))) + truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd) + while true + next = _findnexttruncvalue(Sd, truncdim) + isnothing(next) && break + σmin, cmin = next + truncdim[cmin] -= 1 + err = _compute_truncerr(Sd, truncdim, strategy.p) + if err > strategy.ϵ + truncdim[cmin] += 1 + break + end + if truncdim[cmin] == 0 + delete!(truncdim, cmin) + end + end + return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) +end + +function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted) + @assert strategy.by === abs && strategy.rev == true "Not implemented" + I = keytype(Sd) + S = real(scalartype(valtype(Sd))) + truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd) + totaldim = sum(dim(c) * d for (c, d) in truncdim; init=0) + while true + next = _findnexttruncvalue(Sd, truncdim) + isnothing(next) && break + _, cmin = next + truncdim[cmin] -= 1 + totaldim -= dim(cmin) + if totaldim < strategy.howmany + truncdim[cmin] += 1 + break + end + if truncdim[cmin] == 0 + delete!(truncdim, cmin) + end + end + return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) +end + +function findtruncated_sorted(Sd::SectorDict, strategy::TruncationSpace) + I = keytype(Sd) + return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(min(length(d), + dim(strategy.space, c))) + for (c, d) in Sd) +end + +function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepFiltered) + return SectorDict(c => findtruncated_sorted(d, strategy) for (c, d) in Sd) +end + +function findtruncated_sorted(Sd::SectorDict, strategy::TruncationIntersection) + inds = map(Base.Fix1(findtruncated_sorted, Sd), strategy) + return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...) + for c in intersect(map(keys, inds)...)) +end + +function MatrixAlgebraKit.truncate!(::typeof(left_null!), + (U, S)::Tuple{<:AbstractTensorMap, + <:AbstractTensorMap}, + strategy::MatrixAlgebraKit.TruncationStrategy) + extended_S = SectorDict(c => vcat(MatrixAlgebraKit.diagview(b), + zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) + for (c, b) in blocks(S)) + ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) + V_truncated = spacetype(S)(c => length(axes(b, 1)[ind[c]]) for (c, b) in blocks(S)) + Ũ = similar(U, codomain(U) ← V_truncated) + for (c, b) in blocks(Ũ) + copy!(b, @view(block(U, c)[:, ind[c]])) + end + return Ũ +end diff --git a/src/tensors/factorizations/utility.jl b/src/tensors/factorizations/utility.jl new file mode 100644 index 000000000..e23fb8b73 --- /dev/null +++ b/src/tensors/factorizations/utility.jl @@ -0,0 +1,29 @@ +# convenience to set default +macro check_space(x, V) + return esc(:($MatrixAlgebraKit.@check_size($x, $V, $space))) +end +macro check_scalar(x, y, op=:identity, eltype=:scalartype) + return esc(:($MatrixAlgebraKit.@check_scalar($x, $y, $op, $eltype))) +end + +function factorisation_scalartype(t::AbstractTensorMap) + T = scalartype(t) + return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) +end +factorisation_scalartype(f, t) = factorisation_scalartype(t) + +function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple) + return permute!(similar(t, T, permute(space(t), p)), t, p) +end +function copy_oftype(t::AbstractTensorMap, T::Type{<:Number}) + return copy!(similar(t, T), t) +end + +function _reverse!(t::AbstractTensorMap; dims=:) + for (c, b) in blocks(t) + reverse!(b; dims) + end + return t +end + +diagview(t::AbstractTensorMap) = SectorDict(c => diagview(b) for (c, b) in blocks(t)) diff --git a/src/tensors/matrixalgebrakit.jl b/src/tensors/matrixalgebrakit.jl deleted file mode 100644 index 4481267df..000000000 --- a/src/tensors/matrixalgebrakit.jl +++ /dev/null @@ -1,715 +0,0 @@ -# convenience to set default -macro check_space(x, V) - return esc(:($MatrixAlgebraKit.@check_size($x, $V, $space))) -end -macro check_scalar(x, y, op=:identity, eltype=:scalartype) - return esc(:($MatrixAlgebraKit.@check_scalar($x, $y, $op, $eltype))) -end - -# Generic -# ------- -for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, :svd_full, - :svd_compact, :svd_vals, :svd_trunc) - @eval function MatrixAlgebraKit.copy_input(::typeof($f), - t::AbstractTensorMap{<:BlasFloat}) - T = factorisation_scalartype($f, t) - return copy_oftype(t, T) - end - f! = Symbol(f, :!) - @eval function MatrixAlgebraKit.select_algorithm(::typeof($f!), t::AbstractTensorMap, - alg::Alg=nothing; - kwargs...) where {Alg} - return MatrixAlgebraKit.select_algorithm($f!, typeof(t), alg; kwargs...) - end - @eval function MatrixAlgebraKit.select_algorithm(::typeof($f!), ::Type{T}, - alg::Alg=nothing; - scheduler=default_blockscheduler(T), - kwargs...) where {T<:AbstractTensorMap, - Alg} - mat_alg = MatrixAlgebraKit.select_algorithm($f!, blocktype(T), alg; kwargs...) - return BlockAlgorithm(mat_alg, scheduler) - end -end - -for f in (:qr, :lq, :svd, :eig, :eigh, :polar) - default_f_algorithm = Symbol(:default_, f, :_algorithm) - @eval function MatrixAlgebraKit.$default_f_algorithm(::Type{T}; - scheduler=default_blockscheduler(T), - kwargs...) where {T<:AbstractTensorMap} - return BlockAlgorithm(MatrixAlgebraKit.$default_f_algorithm(blocktype(T); - kwargs...), - scheduler) - end -end - -# TODO: move to MatrixAlgebraKit? -macro check_eltype(x, y, f=:identity, g=:eltype) - msg = "unexpected scalar type: " - msg *= string(g) * "(" * string(x) * ") != " - if f == :identity - msg *= string(g) * "(" * string(y) * ")" - else - msg *= string(f) * "(" * string(y) * ")" - end - return esc(:($g($x) == $f($g($y)) || throw(ArgumentError($msg)))) -end - -function _select_truncation(f, ::AbstractTensorMap, - trunc::MatrixAlgebraKit.TruncationStrategy) - return trunc -end -function _select_truncation(::typeof(left_null!), ::AbstractTensorMap, trunc::NamedTuple) - return MatrixAlgebraKit.null_truncation_strategy(; trunc...) -end - -function MatrixAlgebraKit.diagview(t::AbstractTensorMap) - return SectorDict(c => MatrixAlgebraKit.diagview(b) for (c, b) in blocks(t)) -end - -# Singular value decomposition -# ---------------------------- -const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} -const _T_USVᴴ_diag = Tuple{<:AbstractTensorMap,<:DiagonalTensorMap,<:AbstractTensorMap} - -function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, - (U, S, Vᴴ)::_T_USVᴴ) - # scalartype checks - @check_scalar U t - @check_scalar S t real - @check_scalar Vᴴ t - - # space checks - V_cod = fuse(codomain(t)) - V_dom = fuse(domain(t)) - @check_space(U, codomain(t) ← V_cod) - @check_space(S, V_cod ← V_dom) - @check_space(Vᴴ, V_dom ← domain(t)) - - return nothing -end - -function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorMap, - (U, S, Vᴴ)::_T_USVᴴ_diag) - # scalartype checks - @check_eltype U t - @check_eltype S t real - @check_eltype Vᴴ t - - # space checks - V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(U, codomain(t) ← V_cod) - @check_space(S, V_cod ← V_dom) - @check_space(Vᴴ, V_dom ← domain(t)) - - return nothing -end - -# TODO: svd_vals - -function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_cod = fuse(codomain(t)) - V_dom = fuse(domain(t)) - U = similar(t, codomain(t) ← V_cod) - S = similar(t, real(scalartype(t)), V_cod ← V_dom) - Vᴴ = similar(t, V_dom ← domain(t)) - return U, S, Vᴴ -end - -function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) - U = similar(t, codomain(t) ← V_cod) - S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) - Vᴴ = similar(t, V_dom ← domain(t)) - return U, S, Vᴴ -end - -function MatrixAlgebraKit.initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, - alg::MatrixAlgebraKit.AbstractAlgorithm) - return MatrixAlgebraKit.initialize_output(svd_compact!, t, alg) -end - -# TODO: svd_vals - -function MatrixAlgebraKit.svd_full!(t::AbstractTensorMap, (U, S, Vᴴ), - alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(svd_full!, t, (U, S, Vᴴ)) - - foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ) - if isempty(b) # TODO: remove once MatrixAlgebraKit supports empty matrices - MatrixAlgebraKit.one!(length(u) > 0 ? u : vᴴ) - zerovector!(s) - else - u′, s′, vᴴ′ = MatrixAlgebraKit.svd_full!(b, (u, s, vᴴ), alg.alg) - # deal with the case where the output is not the same as the input - u === u′ || copyto!(u, u′) - s === s′ || copyto!(s, s′) - vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′) - end - return nothing - end - - return U, S, Vᴴ -end - -function MatrixAlgebraKit.svd_compact!(t::AbstractTensorMap, (U, S, Vᴴ), - alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(svd_compact!, t, (U, S, Vᴴ)) - - foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ) - u′, s′, vᴴ′ = svd_compact!(b, (u, s, vᴴ), alg.alg) - # deal with the case where the output is not the same as the input - u === u′ || copyto!(u, u′) - s === s′ || copyto!(s, s′) - vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′) - return nothing - end - - return U, S, Vᴴ -end - -function MatrixAlgebraKit.svd_trunc!(t::AbstractTensorMap, USVᴴ, - alg::MatrixAlgebraKit.TruncatedAlgorithm) - USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) - return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ′, alg.trunc) -end - -# Eigenvalue decomposition -# ------------------------ -const _T_DV = Tuple{<:DiagonalTensorMap,<:AbstractTensorMap} -function MatrixAlgebraKit.check_input(::typeof(eigh_full!), t::AbstractTensorMap, - (D, V)::_T_DV) - domain(t) == codomain(t) || - throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - - # scalartype checks - @check_scalar D t real - @check_scalar V t - - # space checks - V_D = fuse(domain(t)) - @check_space(D, V_D ← V_D) - @check_space(V, codomain(t) ← V_D) - - return nothing -end - -function MatrixAlgebraKit.check_input(::typeof(eig_full!), t::AbstractTensorMap, - (D, V)::_T_DV) - domain(t) == codomain(t) || - throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - - # scalartype checks - @check_scalar D t complex - @check_scalar V t complex - - # space checks - V_D = fuse(domain(t)) - @check_space(D, V_D ← V_D) - @check_space(V, codomain(t) ← V_D) - - return nothing -end - -function MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_D = fuse(domain(t)) - T = real(scalartype(t)) - D = DiagonalTensorMap{T}(undef, V_D) - V = similar(t, codomain(t) ← V_D) - return D, V -end - -function MatrixAlgebraKit.initialize_output(::typeof(eig_full!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_D = fuse(domain(t)) - Tc = complex(scalartype(t)) - D = DiagonalTensorMap{Tc}(undef, V_D) - V = similar(t, Tc, codomain(t) ← V_D) - return D, V -end - -for f in (:eigh_full!, :eig_full!) - @eval function MatrixAlgebraKit.$f(t::AbstractTensorMap, (D, V), - alg::BlockAlgorithm) - MatrixAlgebraKit.check_input($f, t, (D, V)) - - foreachblock(t, D, V; alg.scheduler) do _, (b, d, v) - d′, v′ = $f(b, (d, v), alg.alg) - # deal with the case where the output is not the same as the input - d === d′ || copyto!(d, d′) - v === v′ || copyto!(v, v′) - return nothing - end - - return D, V - end -end - -# QR decomposition -# ---------------- -function MatrixAlgebraKit.check_input(::typeof(qr_full!), t::AbstractTensorMap, - (Q, - R)::Tuple{<:AbstractTensorMap,<:AbstractTensorMap}) - # scalartype checks - @check_scalar Q t - @check_scalar R t - - # space checks - V_Q = fuse(codomain(t)) - @check_space(Q, codomain(t) ← V_Q) - @check_space(R, V_Q ← domain(t)) - - return nothing -end - -function MatrixAlgebraKit.check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)) - # scalartype checks - @check_scalar Q t - @check_scalar R t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(Q, codomain(t) ← V_Q) - @check_space(R, V_Q ← domain(t)) - - return nothing -end - -function MatrixAlgebraKit.check_input(::typeof(qr_null!), t::AbstractTensorMap, - N::AbstractTensorMap) - # scalartype checks - @check_scalar N t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(codomain(t)), V_Q) - @check_space(N, codomain(t) ← V_N) - - return nothing -end - -function MatrixAlgebraKit.initialize_output(::typeof(qr_full!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_Q = fuse(codomain(t)) - Q = similar(t, codomain(t) ← V_Q) - R = similar(t, V_Q ← domain(t)) - return Q, R -end - -function MatrixAlgebraKit.initialize_output(::typeof(qr_compact!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - Q = similar(t, codomain(t) ← V_Q) - R = similar(t, V_Q ← domain(t)) - return Q, R -end - -function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(codomain(t)), V_Q) - N = similar(t, codomain(t) ← V_N) - return N -end - -function MatrixAlgebraKit.qr_full!(t::AbstractTensorMap, (Q, R), - alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(qr_full!, t, (Q, R)) - - foreachblock(t, Q, R; alg.scheduler) do _, (b, q, r) - q′, r′ = qr_full!(b, (q, r), alg.alg) - # deal with the case where the output is not the same as the input - q === q′ || copyto!(q, q′) - r === r′ || copyto!(r, r′) - return nothing - end - - return Q, R -end - -function MatrixAlgebraKit.qr_compact!(t::AbstractTensorMap, (Q, R), - alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(qr_compact!, t, (Q, R)) - - foreachblock(t, Q, R; alg.scheduler) do _, (b, q, r) - q′, r′ = qr_compact!(b, (q, r), alg.alg) - # deal with the case where the output is not the same as the input - q === q′ || copyto!(q, q′) - r === r′ || copyto!(r, r′) - return nothing - end - - return Q, R -end - -function MatrixAlgebraKit.qr_null!(t::AbstractTensorMap, N, alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(qr_null!, t, N) - - foreachblock(t, N; alg.scheduler) do _, (b, n) - n′ = qr_null!(b, n, alg.alg) - # deal with the case where the output is not the same as the input - n === n′ || copyto!(n, n′) - return nothing - end - - return N -end - - -# LQ decomposition -# ---------------- -function MatrixAlgebraKit.check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)) - # scalartype checks - @check_eltype L t - @check_eltype Q t - - # space checks - V_Q = fuse(domain(t)) - space(L) == (codomain(t) ← V_Q) || - throw(SpaceMismatch("`lq_full!(t, (L, Q))` requires `space(L) == (codomain(t) ← fuse(domain(t)))`")) - space(Q) == (V_Q ← domain(t)) || - throw(SpaceMismatch("`lq_full!(t, (L, Q))` requires `space(Q) == (fuse(domain(t)) ← domain(t))`")) - - return nothing -end - -function MatrixAlgebraKit.check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)) - # scalartype checks - @check_scalar L t - @check_scalar Q t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(L, codomain(t) ← V_Q) - @check_space(Q, V_Q ← domain(t)) - - return nothing -end - -function MatrixAlgebraKit.check_input(::typeof(lq_null!), t::AbstractTensorMap, N) - # scalartype checks - @check_scalar N t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(domain(t)), V_Q) - @check_space(N, V_N ← domain(t)) - - return nothing -end - -function MatrixAlgebraKit.initialize_output(::typeof(lq_full!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_Q = fuse(domain(t)) - L = similar(t, codomain(t) ← V_Q) - Q = similar(t, V_Q ← domain(t)) - return L, Q -end - -function MatrixAlgebraKit.initialize_output(::typeof(lq_compact!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - L = similar(t, codomain(t) ← V_Q) - Q = similar(t, V_Q ← domain(t)) - return L, Q -end - -function MatrixAlgebraKit.initialize_output(::typeof(lq_null!), t::AbstractTensorMap, - ::MatrixAlgebraKit.AbstractAlgorithm) - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(domain(t)), V_Q) - N = similar(t, V_N ← domain(t)) - return N -end - -function MatrixAlgebraKit.lq_full!(t::AbstractTensorMap, (L, Q), - alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(lq_full!, t, (L, Q)) - - foreachblock(t, L, Q; alg.scheduler) do _, (b, l, q) - l′, q′ = lq_full!(b, (l, q), alg.alg) - # deal with the case where the output is not the same as the input - l === l′ || copyto!(l, l′) - q === q′ || copyto!(q, q′) - return nothing - end - - return L, Q -end - -function MatrixAlgebraKit.lq_compact!(t::AbstractTensorMap, (L, Q), - alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(lq_compact!, t, (L, Q)) - - foreachblock(t, L, Q; alg.scheduler) do _, (b, l, q) - l′, q′ = lq_compact!(b, (l, q), alg.alg) - # deal with the case where the output is not the same as the input - l === l′ || copyto!(l, l′) - q === q′ || copyto!(q, q′) - return nothing - end - - return L, Q -end - -function MatrixAlgebraKit.lq_null!(t::AbstractTensorMap, N, alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(lq_null!, t, N) - - foreachblock(t, N; alg.scheduler) do _, (b, n) - n′ = lq_null!(b, n, alg.alg) - # deal with the case where the output is not the same as the input - n === n′ || copyto!(n, n′) - return nothing - end - - return N -end - -# Polar decomposition -# ------------------- -using MatrixAlgebraKit: PolarViaSVD - -function MatrixAlgebraKit.check_input(::typeof(left_polar!), t, (W, P)) - codomain(t) ≿ domain(t) || - throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) - - # scalartype checks - @check_scalar W t - @check_scalar P t - - # space checks - @check_space(W, space(t)) - @check_space(P, domain(t) ← domain(t)) - - return nothing -end - -# TODO: do we really not want to fuse the spaces? -function MatrixAlgebraKit.initialize_output(::typeof(left_polar!), t::AbstractTensorMap, - ::BlockAlgorithm) - W = similar(t, space(t)) - P = similar(t, domain(t) ← domain(t)) - return W, P -end - -function MatrixAlgebraKit.left_polar!(t::AbstractTensorMap, WP, alg::BlockAlgorithm) - MatrixAlgebraKit.check_input(left_polar!, t, WP) - - foreachblock(t, WP...; alg.scheduler) do _, (b, w, p) - w′, p′ = left_polar!(b, (w, p), alg.alg) - # deal with the case where the output is not the same as the input - w === w′ || copyto!(w, w′) - p === p′ || copyto!(p, p′) - return nothing - end - - return WP -end - -# Trick to relax the checks of "square" if coming from left_orth -function MatrixAlgebraKit.left_orth_polar!(t::AbstractTensorMap, VC, alg) - alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg) - return MatrixAlgebraKit.left_orth_polar!(t, VC, alg′) -end -function MatrixAlgebraKit.left_orth_polar!(t::AbstractTensorMap, WP, alg::BlockAlgorithm) - foreachblock(t, WP...; alg.scheduler) do _, (b, w, p) - w′, p′ = left_polar!(b, (w, p), alg.alg) - # deal with the case where the output is not the same as the input - w === w′ || copyto!(w, w′) - p === p′ || copyto!(p, p′) - return nothing - end - return WP -end - -# Orthogonalization -# ----------------- -function MatrixAlgebraKit.check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)) - # scalartype checks - @check_scalar V t - isnothing(C) || @check_scalar C t - - # space checks - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(V, codomain(t) ← V_C) - isnothing(C) || @check_space(CV_C ← domain(t)) - - return nothing -end - -function MatrixAlgebraKit.check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)) - # scalartype checks - isnothing(C) || @check_scalar C t - @check_scalar Vᴴ t - - # space checks - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - isnothing(C) || @check_space(C, codomain(t) ← V_C) - @check_space(Vᴴ, V_dom ← domain(t)) - - return nothing -end - -function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), t::AbstractTensorMap) - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - V = similar(t, codomain(t) ← V_C) - C = similar(t, V_C ← domain(t)) - return V, C -end - -function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), t::AbstractTensorMap) - V_C = infimum(fuse(codomain(t)), fuse(domain(t))) - C = similar(t, codomain(t) ← V_C) - Vᴴ = similar(t, V_C ← domain(t)) - return C, Vᴴ -end - -# Nullspace -# --------- -function MatrixAlgebraKit.check_input(::typeof(left_null!), t::AbstractTensorMap, N) - # scalartype checks - @check_scalar N t - - # space checks - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(codomain(t)), V_Q) - @check_space(N, codomain(t) ← V_N) - - return nothing -end - -function MatrixAlgebraKit.initialize_output(::typeof(left_null!), t::AbstractTensorMap) - V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) - V_N = ⊖(fuse(codomain(t)), V_Q) - N = similar(t, codomain(t) ← V_N) - return N -end - -# TODO: the following functions shouldn't be necessary if the AbstractArray restrictions are -# removed -function MatrixAlgebraKit.left_null(t::AbstractTensorMap; kwargs...) - return left_null!(MatrixAlgebraKit.copy_input(left_null, t); kwargs...) -end -function MatrixAlgebraKit.left_null!(t::AbstractTensorMap; kwargs...) - N = MatrixAlgebraKit.initialize_output(left_null!, t) - return left_null!(t, N; kwargs...) -end - -function MatrixAlgebraKit.left_null!(t::AbstractTensorMap, N; - trunc=nothing, - kind=isnothing(trunc) ? :qr : :svd, - alg_qr=(; positive=true), - alg_svd=(;)) - MatrixAlgebraKit.check_input(left_null!, t, N) - - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for left_null with kind=$kind")) - end - - if kind == :qr - @info "qr" - alg_qr′ = MatrixAlgebraKit._select_algorithm(qr_null!, t, alg_qr) - return qr_null!(t, N, alg_qr′) - elseif kind == :svd && isnothing(trunc) - @info "svd" - alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd) - # TODO: refactor into separate function - U, _, _ = svd_full!(t, alg_svd′) - for (c, b) in blocks(N) - bU = block(U, c) - m, n = size(bU) - copy!(b, @view(bU[1:m, (n + 1):m])) - end - return N - elseif kind == :svd - @info "svd2" - alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd) - @show U, S, _ = svd_full!(t, alg_svd′) - @show trunc′ = _select_truncation(left_null!, t, @show trunc) - return MatrixAlgebraKit.truncate!(left_null!, (U, S), trunc′) - else - throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) - end -end - -# Truncation -# ---------- -# TODO: technically we could do this truncation in-place, but this might not be worth it -function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), - trunc::MatrixAlgebraKit.TruncationKeepAbove) - atol = max(trunc.atol, norm(S) * trunc.rtol) - V_truncated = spacetype(S)(c => findlast(>=(atol), b.diag) for (c, b) in blocks(S)) - - Ũ = similar(U, codomain(U) ← V_truncated) - for (c, b) in blocks(Ũ) - copy!(b, @view(block(U, c)[:, 1:size(b, 2)])) - end - - S̃ = DiagonalTensorMap{scalartype(S)}(undef, V_truncated) - for (c, b) in blocks(S̃) - copy!(b.diag, @view(block(S, c).diag[1:size(b, 1)])) - end - - Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) - for (c, b) in blocks(Ṽᴴ) - copy!(b, @view(block(Vᴴ, c)[1:size(b, 1), :])) - end - - return Ũ, S̃, Ṽᴴ -end - -function MatrixAlgebraKit.truncate!(::typeof(left_null!), - (U, S)::Tuple{<:AbstractTensorMap, - <:AbstractTensorMap}, - strategy::MatrixAlgebraKit.TruncationStrategy) - extended_S = SectorDict(c => vcat(MatrixAlgebraKit.diagview(b), - zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) - for (c, b) in blocks(S)) - ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) - V_truncated = spacetype(S)(c => length(axes(b, 1)[ind[c]]) for (c, b) in blocks(S)) - Ũ = similar(U, codomain(U) ← V_truncated) - for (c, b) in blocks(Ũ) - copy!(b, @view(block(U, c)[:, ind[c]])) - end - return Ũ -end - -const BlockWiseTruncations = Union{MatrixAlgebraKit.TruncationKeepAbove, - MatrixAlgebraKit.TruncationKeepBelow, - MatrixAlgebraKit.TruncationKeepFiltered} - -# TODO: relative tolerances should be global -function MatrixAlgebraKit.findtruncated(values::SectorDict, strategy::BlockWiseTruncations) - return SectorDict(c => MatrixAlgebraKit.findtruncated(v, strategy) for (c, v) in values) -end -function MatrixAlgebraKit.findtruncated(vals::SectorDict, - strategy::MatrixAlgebraKit.TruncationKeepSorted) - allpairs = mapreduce(vcat, vals) do (c, v) - return map(Base.Fix1(=>, c), axes(v, 1)) - end - by((c, i)) = strategy.sortby(vals[c][i]) - sort!(allpairs; by, strategy.rev) - - howmany = zero(Base.promote_op(dim, valtype(values))) - i = 1 - while i ≤ length(allpairs) - howmany += dim(first(allpairs[i])) - - howmany == strategy.howmany && break - - if howmany > strategy.howmany - i -= 1 - break - end - - i += 1 - end - - ind = SectorDict(c => allpairs[findall(==(c) ∘ first, view(allpairs, 1:i))] - for c in keys(vals)) - filter!(!isempty ∘ last, ind) # TODO: this might not be necessary - - return ind -end diff --git a/src/tensors/truncation.jl b/src/tensors/truncation.jl deleted file mode 100644 index e49cdc94d..000000000 --- a/src/tensors/truncation.jl +++ /dev/null @@ -1,163 +0,0 @@ -# truncation.jl -# -# Implements truncation schemes for truncating a tensor with svd, leftorth or rightorth -abstract type TruncationScheme end - -struct NoTruncation <: TruncationScheme -end -notrunc() = NoTruncation() - -struct TruncationError{T<:Real} <: TruncationScheme - ϵ::T -end -truncerr(epsilon::Real) = TruncationError(epsilon) - -struct TruncationDimension <: TruncationScheme - dim::Int -end -truncdim(d::Int) = TruncationDimension(d) - -struct TruncationSpace{S<:ElementarySpace} <: TruncationScheme - space::S -end -truncspace(space::ElementarySpace) = TruncationSpace(space) - -struct TruncationCutoff{T<:Real} <: TruncationScheme - ϵ::T - add_back::Int -end -truncbelow(epsilon::Real, add_back::Int=0) = TruncationCutoff(epsilon, add_back) - -# Compute the total truncation error given truncation dimensions -function _compute_truncerr(Σdata, truncdim, p=2) - I = keytype(Σdata) - S = scalartype(valtype(Σdata)) - return _norm((c => view(v, (truncdim[c] + 1):length(v)) for (c, v) in Σdata), p, - zero(S)) -end - -# Compute truncation dimensions -function _compute_truncdim(Σdata, ::NoTruncation, p=2) - I = keytype(Σdata) - truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) - return truncdim -end -function _compute_truncdim(Σdata, trunc::TruncationError, p=2) - I = keytype(Σdata) - S = real(eltype(valtype(Σdata))) - truncdim = SectorDict{I,Int}(c => length(Σc) for (c, Σc) in Σdata) - truncerr = zero(S) - while true - cmin = _findnexttruncvalue(Σdata, truncdim, p) - isnothing(cmin) && break - truncdim[cmin] -= 1 - truncerr = _compute_truncerr(Σdata, truncdim, p) - if truncerr > trunc.ϵ - truncdim[cmin] += 1 - break - end - end - return truncdim -end -function _compute_truncdim(Σdata, trunc::TruncationDimension, p=2) - I = keytype(Σdata) - truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) - while sum(dim(c) * d for (c, d) in truncdim) > trunc.dim - cmin = _findnexttruncvalue(Σdata, truncdim, p) - isnothing(cmin) && break - truncdim[cmin] -= 1 - end - return truncdim -end -function _compute_truncdim(Σdata, trunc::TruncationSpace, p=2) - I = keytype(Σdata) - truncdim = SectorDict{I,Int}(c => min(length(v), dim(trunc.space, c)) - for (c, v) in Σdata) - return truncdim -end - -function _compute_truncdim(Σdata, trunc::TruncationCutoff, p=2) - I = keytype(Σdata) - truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) - for (c, v) in Σdata - newdim = findlast(Base.Fix2(>, trunc.ϵ), v) - if newdim === nothing - truncdim[c] = 0 - else - truncdim[c] = newdim - end - end - for i in 1:(trunc.add_back) - cmax = _findnextgrowvalue(Σdata, truncdim, p) - isnothing(cmax) && break - truncdim[cmax] += 1 - end - return truncdim -end - -# Combine truncations -struct MultipleTruncation{T<:Tuple{Vararg{TruncationScheme}}} <: TruncationScheme - truncations::T -end -function Base.:&(a::MultipleTruncation, b::MultipleTruncation) - return MultipleTruncation((a.truncations..., b.truncations...)) -end -function Base.:&(a::MultipleTruncation, b::TruncationScheme) - return MultipleTruncation((a.truncations..., b)) -end -function Base.:&(a::TruncationScheme, b::MultipleTruncation) - return MultipleTruncation((a, b.truncations...)) -end -Base.:&(a::TruncationScheme, b::TruncationScheme) = MultipleTruncation((a, b)) - -function _compute_truncdim(Σdata, trunc::MultipleTruncation, p::Real=2) - truncdim = _compute_truncdim(Σdata, trunc.truncations[1], p) - for k in 2:length(trunc.truncations) - truncdimₖ = _compute_truncdim(Σdata, trunc.truncations[k], p) - for (c, d) in truncdim - truncdim[c] = min(d, truncdimₖ[c]) - end - end - return truncdim -end - -# auxiliary function -function _findnexttruncvalue(Σdata, truncdim::SectorDict{I,Int}, p::Real) where {I<:Sector} - # early return - (isempty(Σdata) || all(iszero, values(truncdim))) && return nothing - - # find some suitable starting candidate - cmin = findfirst(>(0), truncdim) - σmin = Σdata[cmin][truncdim[cmin]] - - # find the actual minimum singular value - for (c, σs) in Σdata - if truncdim[c] > 0 - σ = σs[truncdim[c]] - if σ < σmin - cmin, σmin = c, σ - end - end - end - return cmin -end -function _findnextgrowvalue(Σdata, truncdim::SectorDict{I,Int}, p::Real) where {I<:Sector} - istruncated = SectorDict{I,Bool}(c => (d < length(Σdata[c])) for (c, d) in truncdim) - # early return - (isempty(Σdata) || !any(values(istruncated))) && return nothing - - # find some suitable starting candidate - cmax = findfirst(istruncated) - σmax = Σdata[cmax][truncdim[cmax] + 1] - - # find the actual maximal singular value - for (c, σs) in Σdata - if istruncated[c] - σ = σs[truncdim[c] + 1] - if σ > σmax - cmax, σmax = c, σ - end - end - end - return cmax -end diff --git a/test/factorizations.jl b/test/factorizations.jl index ca9b510fb..17642aa93 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -1,318 +1,340 @@ using TestEnv; TestEnv.activate(); -using Test -using TestExtras -using Random -using TensorKit -using Combinatorics -using TensorKit: ProductSector, fusiontensor, pentagon_equation, hexagon_equation -using TensorOperations -using Base.Iterators: take, product -# using SUNRepresentations: SUNIrrep -# const SU3Irrep = SUNIrrep{3} -using LinearAlgebra: LinearAlgebra -using Zygote: Zygote -using MatrixAlgebraKit +@testsnippet Setup begin + using Test + using TestExtras + using Random + using TensorKit + using Combinatorics + using TensorKit: ProductSector, fusiontensor, pentagon_equation, hexagon_equation + using TensorOperations + using Base.Iterators: take, product + # using SUNRepresentations: SUNIrrep + # const SU3Irrep = SUNIrrep{3} + using LinearAlgebra: LinearAlgebra + using Zygote: Zygote + using MatrixAlgebraKit -const TK = TensorKit + const TK = TensorKit -Random.seed!(1234) + Random.seed!(1234) -smallset(::Type{I}) where {I<:Sector} = take(values(I), 5) -function smallset(::Type{ProductSector{Tuple{I1,I2}}}) where {I1,I2} - iter = product(smallset(I1), smallset(I2)) - s = collect(i ⊠ j for (i, j) in iter if dim(i) * dim(j) <= 6) - return length(s) > 6 ? rand(s, 6) : s -end -function smallset(::Type{ProductSector{Tuple{I1,I2,I3}}}) where {I1,I2,I3} - iter = product(smallset(I1), smallset(I2), smallset(I3)) - s = collect(i ⊠ j ⊠ k for (i, j, k) in iter if dim(i) * dim(j) * dim(k) <= 6) - return length(s) > 6 ? rand(s, 6) : s -end -function randsector(::Type{I}) where {I<:Sector} - s = collect(smallset(I)) - a = rand(s) - while a == one(a) # don't use trivial label + smallset(::Type{I}) where {I<:Sector} = take(values(I), 5) + function smallset(::Type{ProductSector{Tuple{I1,I2}}}) where {I1,I2} + iter = product(smallset(I1), smallset(I2)) + s = collect(i ⊠ j for (i, j) in iter if dim(i) * dim(j) <= 6) + return length(s) > 6 ? rand(s, 6) : s + end + function smallset(::Type{ProductSector{Tuple{I1,I2,I3}}}) where {I1,I2,I3} + iter = product(smallset(I1), smallset(I2), smallset(I3)) + s = collect(i ⊠ j ⊠ k for (i, j, k) in iter if dim(i) * dim(j) * dim(k) <= 6) + return length(s) > 6 ? rand(s, 6) : s + end + function randsector(::Type{I}) where {I<:Sector} + s = collect(smallset(I)) a = rand(s) + while a == one(a) # don't use trivial label + a = rand(s) + end + return a end - return a -end -function hasfusiontensor(I::Type{<:Sector}) - try - fusiontensor(one(I), one(I), one(I)) - return true - catch e - if e isa MethodError - return false - else - rethrow(e) + function hasfusiontensor(I::Type{<:Sector}) + try + fusiontensor(one(I), one(I), one(I)) + return true + catch e + if e isa MethodError + return false + else + rethrow(e) + end end end -end -# spaces -Vtr = (ℂ^3, - (ℂ^4)', - ℂ^5, - ℂ^6, - (ℂ^7)') -Vℤ₂ = (ℂ[Z2Irrep](0 => 1, 1 => 1), - ℂ[Z2Irrep](0 => 1, 1 => 2)', - ℂ[Z2Irrep](0 => 3, 1 => 2)', - ℂ[Z2Irrep](0 => 2, 1 => 3), - ℂ[Z2Irrep](0 => 2, 1 => 5)) -Vfℤ₂ = (ℂ[FermionParity](0 => 1, 1 => 1), - ℂ[FermionParity](0 => 1, 1 => 2)', - ℂ[FermionParity](0 => 3, 1 => 2)', - ℂ[FermionParity](0 => 2, 1 => 3), - ℂ[FermionParity](0 => 2, 1 => 5)) -Vℤ₃ = (ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 2), - ℂ[Z3Irrep](0 => 3, 1 => 1, 2 => 1), - ℂ[Z3Irrep](0 => 2, 1 => 2, 2 => 1)', - ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 3), - ℂ[Z3Irrep](0 => 1, 1 => 3, 2 => 3)') -VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), - ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), - ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', - ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 3), - ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 3)') -VfU₁ = (ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 2), - ℂ[FermionNumber](0 => 3, 1 => 1, -1 => 1), - ℂ[FermionNumber](0 => 2, 1 => 2, -1 => 1)', - ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 3), - ℂ[FermionNumber](0 => 1, 1 => 3, -1 => 3)') -VCU₁ = (ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 2, 1 => 1), - ℂ[CU1Irrep]((0, 0) => 3, (0, 1) => 0, 1 => 1), - ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 0, 1 => 2)', - ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 2, 1 => 1), - ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 1, 1 => 2)') -VSU₂ = (ℂ[SU2Irrep](0 => 3, 1 // 2 => 1), - ℂ[SU2Irrep](0 => 2, 1 => 1), - ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)', - ℂ[SU2Irrep](0 => 2, 1 // 2 => 2), - ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') -VfSU₂ = (ℂ[FermionSpin](0 => 3, 1 // 2 => 1), - ℂ[FermionSpin](0 => 2, 1 => 1), - ℂ[FermionSpin](1 // 2 => 1, 1 => 1)', - ℂ[FermionSpin](0 => 2, 1 // 2 => 2), - ℂ[FermionSpin](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') -for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) - V1, V2, V3, V4, V5 = V + # spaces + Vtr = (ℂ^3, + (ℂ^4)', + ℂ^5, + ℂ^6, + (ℂ^7)') + Vℤ₂ = (ℂ[Z2Irrep](0 => 1, 1 => 1), + ℂ[Z2Irrep](0 => 1, 1 => 2)', + ℂ[Z2Irrep](0 => 3, 1 => 2)', + ℂ[Z2Irrep](0 => 2, 1 => 3), + ℂ[Z2Irrep](0 => 2, 1 => 5)) + Vfℤ₂ = (ℂ[FermionParity](0 => 1, 1 => 1), + ℂ[FermionParity](0 => 1, 1 => 2)', + ℂ[FermionParity](0 => 3, 1 => 2)', + ℂ[FermionParity](0 => 2, 1 => 3), + ℂ[FermionParity](0 => 2, 1 => 5)) + Vℤ₃ = (ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 2), + ℂ[Z3Irrep](0 => 3, 1 => 1, 2 => 1), + ℂ[Z3Irrep](0 => 2, 1 => 2, 2 => 1)', + ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 3), + ℂ[Z3Irrep](0 => 1, 1 => 3, 2 => 3)') + VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), + ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), + ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 3), + ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 3)') + VfU₁ = (ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 2), + ℂ[FermionNumber](0 => 3, 1 => 1, -1 => 1), + ℂ[FermionNumber](0 => 2, 1 => 2, -1 => 1)', + ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 3), + ℂ[FermionNumber](0 => 1, 1 => 3, -1 => 3)') + VCU₁ = (ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 2, 1 => 1), + ℂ[CU1Irrep]((0, 0) => 3, (0, 1) => 0, 1 => 1), + ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 0, 1 => 2)', + ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 2, 1 => 1), + ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 1, 1 => 2)') + VSU₂ = (ℂ[SU2Irrep](0 => 3, 1 // 2 => 1), + ℂ[SU2Irrep](0 => 2, 1 => 1), + ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)', + ℂ[SU2Irrep](0 => 2, 1 // 2 => 2), + ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') + VfSU₂ = (ℂ[FermionSpin](0 => 3, 1 // 2 => 1), + ℂ[FermionSpin](0 => 2, 1 => 1), + ℂ[FermionSpin](1 // 2 => 1, 1 => 1)', + ℂ[FermionSpin](0 => 2, 1 // 2 => 2), + ℂ[FermionSpin](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') + for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) + V1, V2, V3, V4, V5 = V - @assert V3 * V4 * V2 ≿ V1' * V5' # necessary for leftorth tests - @assert V3 * V4 ≾ V1' * V2' * V5' # necessary for rightorth tests -end + @assert V3 * V4 * V2 ≿ V1' * V5' # necessary for leftorth tests + @assert V3 * V4 ≾ V1' * V2' * V5' # necessary for rightorth tests + end -spacelist = try - if ENV["CI"] == "true" - println("Detected running on CI") - if Sys.iswindows() - (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂) - elseif Sys.isapple() - (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VfU₁, VfSU₂)#, VSU₃) + spacelist = try + if ENV["CI"] == "true" + println("Detected running on CI") + if Sys.iswindows() + (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂) + elseif Sys.isapple() + (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VfU₁, VfSU₂)#, VSU₃) + else + (Vtr, Vℤ₂, Vfℤ₂, VU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) + end else - (Vtr, Vℤ₂, Vfℤ₂, VU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) + (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) end - else + catch (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) end -catch - (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) end +@testitem "left orth" setup = [Setup] begin + function test_leftorth(t, p, alg) + Q, R = @inferred leftorth(t, p; alg) + @test Q * R ≈ permute(t, p) + @test isisometry(Q) + end + + p = ((3, 4, 2), (1, 5)) + elts = (Float32, ComplexF64) + algs = (TensorKit.QR(), TensorKit.QRpos(), TensorKit.QL(), TensorKit.QLpos(), + TensorKit.Polar(), TensorKit.SVD(), TensorKit.SDD()) -function test_leftorth(t, p, alg) - Q, R = @inferred leftorth(t, p; alg) - @test Q * R ≈ permute(t, p) - @test isisometry(Q) - if alg isa Polar - @test isposdef(R) - @test domain(R) == codomain(R) == domain(permute(space(t), p)) + testname(V) = "symmetry: $(TensorKit.type_repr(sectortype(first(V))))" + @timedtestset "$(testname(V))" for V in spacelist + W = ⊗(V...) + for T in elts, alg in algs + t = rand(T, W) + test_leftorth(t, p, alg) + tᴴ = t' + test_leftorth(tᴴ, p, alg) + end end end + function test_leftnull(t, p, alg) N = @inferred leftnull(t, p; alg) @test isisometry(N) - @test norm(N' * permute(t, p)) ≈ 0 atol= 100 * eps(norm(t)) + @test norm(N' * permute(t, p)) ≈ 0 atol = 100 * eps(norm(t)) end # @timedtestset "Factorizations with symmetry: $(sectortype(first(V)))" for V in spacelist - V = collect(spacelist)[2] - V1, V2, V3, V4, V5 = V - W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 - for T in (Float32, ComplexF64), adj in (false, true) - t = adj ? rand(T, W)' : rand(T, W); - @testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QRpos(), TensorKit.QL(), TensorKit.QLpos(), TensorKit.Polar(), TensorKit.SVD(), TensorKit.SDD()) - test_leftorth(t, ((3, 4, 2), (1, 5)), alg) +V = collect(spacelist)[1] +V1, V2, V3, V4, V5 = V +W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 +for T in (Float32, ComplexF64), adj in (false, true) + t = adj ? rand(T, W)' : rand(T, W) + @testset "leftorth with $alg" for alg in + (TensorKit.QR(), TensorKit.QRpos(), TensorKit.QL(), + TensorKit.QLpos(), TensorKit.Polar(), + TensorKit.SVD(), TensorKit.SDD()) + test_leftorth(t, ((3, 4, 2), (1, 5)), alg) + end + @testset "leftnull with $alg" for alg in + (TensorKit.QR(), TensorKit.SVD(), TensorKit.SDD()) + test_leftnull(t, ((3, 4, 2), (1, 5)), alg) + end + @testset "rightorth with $alg" for alg in + (TensorKit.RQ(), TensorKit.RQpos(), + TensorKit.LQ(), TensorKit.LQpos(), + TensorKit.Polar(), TensorKit.SVD(), + TensorKit.SDD()) + L, Q = @constinferred rightorth(t, ((3, 4), (2, 1, 5)); alg=alg) + QQd = Q * Q' + @test QQd ≈ one(QQd) + @test L * Q ≈ permute(t, ((3, 4), (2, 1, 5))) + if alg isa Polar + @test isposdef(L) + @test domain(L) == codomain(L) == space(t, 3) ⊗ space(t, 4) + end + end + @testset "rightnull with $alg" for alg in + (TensorKit.LQ(), TensorKit.SVD(), + TensorKit.SDD()) + M = @constinferred rightnull(t, ((3, 4), (2, 1, 5)); alg=alg) + MMd = M * M' + @test MMd ≈ one(MMd) + @test norm(permute(t, ((3, 4), (2, 1, 5))) * M') < + 100 * eps(norm(t)) + end + @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) + U, S, V = @constinferred tsvd(t, ((3, 4, 2), (1, 5)); alg=alg) + UdU = U' * U + @test UdU ≈ one(UdU) + VVd = V * V' + @test VVd ≈ one(VVd) + t2 = permute(t, ((3, 4, 2), (1, 5))) + @test U * S * V ≈ t2 + + s = LinearAlgebra.svdvals(t2) + s′ = LinearAlgebra.diag(S) + for (c, b) in s + @test b ≈ s′[c] end - @testset "leftnull with $alg" for alg in (TensorKit.QR(), TensorKit.SVD(), TensorKit.SDD()) - test_leftnull(t, ((3, 4, 2), (1, 5)), alg) + end + @testset "cond and rank" begin + t2 = permute(t, ((3, 4, 2), (1, 5))) + d1 = dim(codomain(t2)) + d2 = dim(domain(t2)) + @test rank(t2) == min(d1, d2) + M = leftnull(t2) + @test rank(M) == max(d1, d2) - min(d1, d2) + t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) + @test cond(t3) ≈ one(real(T)) + @test rank(t3) == dim(V1 ⊗ V2) + t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) + t4 = (t4 + t4') / 2 + vals = LinearAlgebra.eigvals(t4) + λmax = maximum(s -> maximum(abs, s), values(vals)) + λmin = minimum(s -> minimum(abs, s), values(vals)) + @test cond(t4) ≈ λmax / λmin + end +end + +@testset "empty tensor" begin + for T in (Float32, ComplexF64) + T = Float64 + t = randn(T, V1 ⊗ V2, zero(V1)) + @testset "leftorth with $alg" for alg in + (TensorKit.QR(), TensorKit.QRpos(), + TensorKit.QL(), TensorKit.QLpos(), + TensorKit.Polar(), TensorKit.SVD(), + TensorKit.SDD()) + Q, R = @constinferred leftorth(t; alg=alg) + @test Q == t + @test dim(Q) == dim(R) == 0 + end + @testset "leftnull with $alg" for alg in + (TensorKit.QR(), TensorKit.SVD(), + TensorKit.SDD()) + N = @constinferred leftnull(t; alg=alg) + @test N' * N ≈ id(domain(N)) + @test N * N' ≈ id(codomain(N)) end @testset "rightorth with $alg" for alg in (TensorKit.RQ(), TensorKit.RQpos(), TensorKit.LQ(), TensorKit.LQpos(), TensorKit.Polar(), TensorKit.SVD(), TensorKit.SDD()) - L, Q = @constinferred rightorth(t, ((3, 4), (2, 1, 5)); alg=alg) - QQd = Q * Q' - @test QQd ≈ one(QQd) - @test L * Q ≈ permute(t, ((3, 4), (2, 1, 5))) - if alg isa Polar - @test isposdef(L) - @test domain(L) == codomain(L) == space(t, 3) ⊗ space(t, 4) - end + L, Q = @constinferred rightorth(copy(t'); alg=alg) + @test Q == t' + @test dim(Q) == dim(L) == 0 end @testset "rightnull with $alg" for alg in (TensorKit.LQ(), TensorKit.SVD(), TensorKit.SDD()) - M = @constinferred rightnull(t, ((3, 4), (2, 1, 5)); alg=alg) - MMd = M * M' - @test MMd ≈ one(MMd) - @test norm(permute(t, ((3, 4), (2, 1, 5))) * M') < - 100 * eps(norm(t)) + M = @constinferred rightnull(copy(t'); alg=alg) + @test M * M' ≈ id(codomain(M)) + @test M' * M ≈ id(domain(M)) end @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) - U, S, V = @constinferred tsvd(t, ((3, 4, 2), (1, 5)); alg=alg) - UdU = U' * U - @test UdU ≈ one(UdU) - VVd = V * V' - @test VVd ≈ one(VVd) - t2 = permute(t, ((3, 4, 2), (1, 5))) - @test U * S * V ≈ t2 - - s = LinearAlgebra.svdvals(t2) - s′ = LinearAlgebra.diag(S) - for (c, b) in s - @test b ≈ s′[c] - end + U, S, V = @constinferred tsvd(t; alg=alg) + @test U == t + @test dim(U) == dim(S) == dim(V) end @testset "cond and rank" begin - t2 = permute(t, ((3, 4, 2), (1, 5))) - d1 = dim(codomain(t2)) - d2 = dim(domain(t2)) - @test rank(t2) == min(d1, d2) - M = leftnull(t2) - @test rank(M) == max(d1, d2) - min(d1, d2) - t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) - @test cond(t3) ≈ one(real(T)) - @test rank(t3) == dim(V1 ⊗ V2) - t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) - t4 = (t4 + t4') / 2 - vals = LinearAlgebra.eigvals(t4) - λmax = maximum(s -> maximum(abs, s), values(vals)) - λmin = minimum(s -> minimum(abs, s), values(vals)) - @test cond(t4) ≈ λmax / λmin - end - end - @testset "empty tensor" begin - for T in (Float32, ComplexF64) - t = randn(T, V1 ⊗ V2, zero(V1)) - @testset "leftorth with $alg" for alg in - (TensorKit.QR(), TensorKit.QRpos(), - TensorKit.QL(), TensorKit.QLpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - Q, R = @constinferred leftorth(t; alg=alg) - @test Q == t - @test dim(Q) == dim(R) == 0 - end - @testset "leftnull with $alg" for alg in - (TensorKit.QR(), TensorKit.SVD(), - TensorKit.SDD()) - N = @constinferred leftnull(t; alg=alg) - @test N' * N ≈ id(domain(N)) - @test N * N' ≈ id(codomain(N)) - end - @testset "rightorth with $alg" for alg in - (TensorKit.RQ(), TensorKit.RQpos(), - TensorKit.LQ(), TensorKit.LQpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - L, Q = @constinferred rightorth(copy(t'); alg=alg) - @test Q == t' - @test dim(Q) == dim(L) == 0 - end - @testset "rightnull with $alg" for alg in - (TensorKit.LQ(), TensorKit.SVD(), - TensorKit.SDD()) - M = @constinferred rightnull(copy(t'); alg=alg) - @test M * M' ≈ id(codomain(M)) - @test M' * M ≈ id(domain(M)) - end - @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) - U, S, V = @constinferred tsvd(t; alg=alg) - @test U == t - @test dim(U) == dim(S) == dim(V) - end - @testset "cond and rank" begin - @test rank(t) == 0 - W2 = zero(V1) * zero(V2) - t2 = rand(W2, W2) - @test rank(t2) == 0 - @test cond(t2) == 0.0 - end + @test rank(t) == 0 + W2 = zero(V1) * zero(V2) + t2 = rand(W2, W2) + @test rank(t2) == 0 + @test cond(t2) == 0.0 end end - @testset "eig and isposdef" begin - for T in (Float32, ComplexF64) - t = rand(T, V1 ⊗ V1' ⊗ V2 ⊗ V2') - D, V = eigen(t, ((1, 3), (2, 4))) - t2 = permute(t, ((1, 3), (2, 4))) - @test t2 * V ≈ V * D +end +@testset "eig and isposdef" begin + for T in (Float32, ComplexF64) + t = rand(T, V1 ⊗ V1' ⊗ V2 ⊗ V2') + D, V = eigen(t, ((1, 3), (2, 4))) + t2 = permute(t, ((1, 3), (2, 4))) + @test t2 * V ≈ V * D - d = LinearAlgebra.eigvals(t2; sortby=nothing) - d′ = LinearAlgebra.diag(D) - for (c, b) in d - @test b ≈ d′[c] - end + d = LinearAlgebra.eigvals(t2; sortby=nothing) + d′ = LinearAlgebra.diag(D) + for (c, b) in d + @test b ≈ d′[c] + end - # Somehow moving these test before the previous one gives rise to errors - # with T=Float32 on x86 platforms. Is this an OpenBLAS issue? - VdV = V' * V - VdV = (VdV + VdV') / 2 - @test isposdef(VdV) + # Somehow moving these test before the previous one gives rise to errors + # with T=Float32 on x86 platforms. Is this an OpenBLAS issue? + VdV = V' * V + VdV = (VdV + VdV') / 2 + @test isposdef(VdV) - @test !isposdef(t2) # unlikely for non-hermitian map - t2 = (t2 + t2') - D, V = eigen(t2) - VdV = V' * V - @test VdV ≈ one(VdV) - D̃, Ṽ = @constinferred eigh(t2) - @test D ≈ D̃ - @test V ≈ Ṽ - λ = minimum(minimum(real(LinearAlgebra.diag(b))) - for (c, b) in blocks(D)) - @test cond(Ṽ) ≈ one(real(T)) - @test isposdef(t2) == isposdef(λ) - @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) - @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) - end + @test !isposdef(t2) # unlikely for non-hermitian map + t2 = (t2 + t2') + D, V = eigen(t2) + VdV = V' * V + @test VdV ≈ one(VdV) + D̃, Ṽ = @constinferred eigh(t2) + @test D ≈ D̃ + @test V ≈ Ṽ + λ = minimum(minimum(real(LinearAlgebra.diag(b))) + for (c, b) in blocks(D)) + @test cond(Ṽ) ≈ one(real(T)) + @test isposdef(t2) == isposdef(λ) + @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) + @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) end - @testset "Tensor truncation" begin - for T in (Float32, ComplexF64), p in (1, 2, 3, Inf), adj in (false, true) - t = adj ? rand(T, V1 ⊗ V2 ⊗ V3, V4 ⊗ V5) : rand(T, V4 ⊗ V5, V1 ⊗ V2 ⊗ V3)' +end +@testset "Tensor truncation" begin + for T in (Float32, ComplexF64), p in (1, 2, 3, Inf), adj in (false, true) + t = adj ? rand(T, V1 ⊗ V2 ⊗ V3, V4 ⊗ V5) : rand(T, V4 ⊗ V5, V1 ⊗ V2 ⊗ V3)' - U₀, S₀, V₀, = tsvd(t) - t = rmul!(t, 1 / norm(S₀, p)) - U, S, V, ϵ = @constinferred tsvd(t; trunc=truncerr(5e-1), p=p) - # @show p, ϵ - # @show domain(S) - # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncerr(nextfloat(ϵ)), p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S)))), - p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - # results with truncationcutoff cannot be compared because they don't take degeneracy into account, and thus truncate differently - U, S, V, ϵ = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀))), p=p) - # @show p, ϵ - # @show domain(S) - # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - end + U₀, S₀, V₀, = tsvd(t) + t = rmul!(t, 1 / norm(S₀, p)) + U, S, V, ϵ = @constinferred tsvd(t; trunc=truncerr(5e-1), p=p) + # @show p, ϵ + # @show domain(S) + # @test min(space(S,1), space(S₀,1)) != space(S₀,1) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncerr(nextfloat(ϵ)), p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S)))), + p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) + # results with truncationcutoff cannot be compared because they don't take degeneracy into account, and thus truncate differently + U, S, V, ϵ = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀))), p=p) + # @show p, ϵ + # @show domain(S) + # @test min(space(S,1), space(S₀,1)) != space(S₀,1) + U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) end end +# end diff --git a/test/tensors.jl b/test/tensors.jl index 30526f2c1..4280348fb 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -454,10 +454,11 @@ for V in spacelist QdQ = Q' * Q @test QdQ ≈ one(QdQ) @test Q * R ≈ permute(t, ((3, 4, 2), (1, 5))) - if alg isa Polar - @test isposdef(R) - @test domain(R) == codomain(R) == space(t, 1)' ⊗ space(t, 5)' - end + # removed since leftorth now merges legs! + # if alg isa Polar + # @test isposdef(R) + # @test domain(R) == codomain(R) == space(t, 1)' ⊗ space(t, 5)' + # end end @testset "leftnull with $alg" for alg in (TensorKit.QR(), TensorKit.SVD(), @@ -477,10 +478,11 @@ for V in spacelist QQd = Q * Q' @test QQd ≈ one(QQd) @test L * Q ≈ permute(t, ((3, 4), (2, 1, 5))) - if alg isa Polar - @test isposdef(L) - @test domain(L) == codomain(L) == space(t, 3) ⊗ space(t, 4) - end + # removed since rightorth now merges legs! + # if alg isa Polar + # @test isposdef(L) + # @test domain(L) == codomain(L) == space(t, 3) ⊗ space(t, 4) + # end end @testset "rightnull with $alg" for alg in (TensorKit.LQ(), TensorKit.SVD(), @@ -615,23 +617,36 @@ for V in spacelist for t in ts U₀, S₀, V₀, = tsvd(t) t = rmul!(t, 1 / norm(S₀, p)) - U, S, V, ϵ = @constinferred tsvd(t; trunc=truncerr(5e-1), p=p) + U, S, V = @constinferred tsvd(t; trunc=truncerr(5e-1, p)) + ϵ = TensorKit._norm(LinearAlgebra.svdvals(U * S * V - t), p, + zero(scalartype(S))) + p == 2 && @test ϵ < 5e-1 # @show p, ϵ # @show domain(S) # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncerr(nextfloat(ϵ)), p=p) + U′, S′, V′ = tsvd(t; trunc=truncerr(ϵ + 10eps(ϵ), p)) + ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, + zero(scalartype(S))) + @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S)))), - p=p) + U′, S′, V′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S))))) + ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, + zero(scalartype(S))) @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) + U′, S′, V′ = tsvd(t; trunc=truncspace(space(S, 1))) + ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, + zero(scalartype(S))) @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) # results with truncationcutoff cannot be compared because they don't take degeneracy into account, and thus truncate differently - U, S, V, ϵ = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀))), p=p) + U, S, V = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀)))) + ϵ = TensorKit._norm(LinearAlgebra.svdvals(U * S * V - t), p, + zero(scalartype(S))) # @show p, ϵ # @show domain(S) # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) + U′, S′, V′ = tsvd(t; trunc=truncspace(space(S, 1))) + ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, + zero(scalartype(S))) @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) end end @@ -691,8 +706,8 @@ for V in spacelist for T in (Float32, ComplexF64) tA = rand(T, V1 ⊗ V3, V1 ⊗ V3) tB = rand(T, V2 ⊗ V4, V2 ⊗ V4) - tA = 3 // 2 * leftorth(tA; alg=Polar())[1] - tB = 1 // 5 * leftorth(tB; alg=Polar())[1] + tA = 3 // 2 * leftpolar(tA)[1] + tB = 1 // 5 * leftpolar(tB)[1] tC = rand(T, V1 ⊗ V3, V2 ⊗ V4) t = @constinferred sylvester(tA, tB, tC) @test codomain(t) == V1 ⊗ V3 From 7d2cff0a840055043e3892c6a8ed7c7ca49f2cf7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 11 Jun 2025 21:25:39 -0400 Subject: [PATCH 031/126] Clean up truncation --- src/tensors/factorizations/truncation.jl | 205 ++++++----------------- 1 file changed, 48 insertions(+), 157 deletions(-) diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index e9806645a..4f7cec733 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -1,149 +1,26 @@ -# truncation.jl -# -# Implements truncation schemes for truncating a tensor with svd, leftorth or rightorth - +# Strategies +# ---------- notrunc() = NoTruncation() + # deprecate const TruncationScheme = TruncationStrategy +@deprecate truncdim(d::Int) truncrank(d) +@deprecate truncbelow(ϵ::Real, add_back::Int=0) trunctol(ϵ) +# TODO: add this to MatrixAlgebraKit struct TruncationError{T<:Real} <: TruncationStrategy ϵ::T p::Real end truncerr(epsilon::Real, p::Real=2) = TruncationError(epsilon, p) -# struct TruncationDimension <: TruncationScheme -# dim::Int -# end -@deprecate truncdim(d::Int) truncrank(d) - struct TruncationSpace{S<:ElementarySpace} <: TruncationStrategy space::S end truncspace(space::ElementarySpace) = TruncationSpace(space) -struct TruncationCutoff{T<:Real} <: TruncationStrategy - ϵ::T - add_back::Int -end -@deprecate truncbelow(ϵ::Real, add_back::Int=0) begin - add_back == 0 || @warn "add_back is ignored" - trunctol(ϵ) -end -# truncbelow(epsilon::Real, add_back::Int=0) = TruncationCutoff(epsilon, add_back) - -# Compute the total truncation error given truncation dimensions -function _compute_truncerr(Σdata, truncdim, p=2) - I = keytype(Σdata) - S = scalartype(valtype(Σdata)) - return TensorKit._norm((c => @view(v[(get(truncdim, c, 0) + 1):end]) - for (c, v) in Σdata), - p, zero(S)) -end - -# Compute truncation dimensions -# function _compute_truncdim(Σdata, ::NoTruncation, p=2) -# I = keytype(Σdata) -# truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) -# return truncdim -# end -# function _compute_truncdim(Σdata, trunc::TruncationDimension, p=2) -# I = keytype(Σdata) -# truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) -# while sum(dim(c) * d for (c, d) in truncdim) > trunc.dim -# cmin = _findnexttruncvalue(Σdata, truncdim, p) -# isnothing(cmin) && break -# truncdim[cmin] -= 1 -# end -# return truncdim -# end -function _compute_truncdim(Σdata, trunc::TruncationSpace, p=2) - I = keytype(Σdata) - truncdim = SectorDict{I,Int}(c => min(length(v), dim(trunc.space, c)) - for (c, v) in Σdata) - return truncdim -end - -# function _compute_truncdim(Σdata, trunc::TruncationCutoff, p=2) -# I = keytype(Σdata) -# truncdim = SectorDict{I,Int}(c => length(v) for (c, v) in Σdata) -# for (c, v) in Σdata -# newdim = findlast(Base.Fix2(>, trunc.ϵ), v) -# if newdim === nothing -# truncdim[c] = 0 -# else -# truncdim[c] = newdim -# end -# end -# for i in 1:(trunc.add_back) -# cmax = _findnextgrowvalue(Σdata, truncdim, p) -# isnothing(cmax) && break -# truncdim[cmax] += 1 -# end -# return truncdim -# end - -# Combine truncations -# struct MultipleTruncation{T<:Tuple{Vararg{TruncationScheme}}} <: TruncationScheme -# truncations::T -# end -# function Base.:&(a::MultipleTruncation, b::MultipleTruncation) -# return MultipleTruncation((a.truncations..., b.truncations...)) -# end -# function Base.:&(a::MultipleTruncation, b::TruncationScheme) -# return MultipleTruncation((a.truncations..., b)) -# end -# function Base.:&(a::TruncationScheme, b::MultipleTruncation) -# return MultipleTruncation((a, b.truncations...)) -# end -# Base.:&(a::TruncationScheme, b::TruncationScheme) = MultipleTruncation((a, b)) - -# function _compute_truncdim(Σdata, trunc::MultipleTruncation, p::Real=2) -# truncdim = _compute_truncdim(Σdata, trunc.truncations[1], p) -# for k in 2:length(trunc.truncations) -# truncdimₖ = _compute_truncdim(Σdata, trunc.truncations[k], p) -# for (c, d) in truncdim -# truncdim[c] = min(d, truncdimₖ[c]) -# end -# end -# return truncdim -# end - -# auxiliary function -function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector} - # early return - (isempty(S) || all(iszero, values(truncdim))) && return nothing - σmin, imin = findmin(keys(truncdim)) do c - d = truncdim[c] - return S[c][d] - end - return σmin, keys(truncdim)[imin] -end - -function _findnextgrowvalue(Σdata, truncdim::SectorDict{I,Int}, p::Real) where {I<:Sector} - istruncated = SectorDict{I,Bool}(c => (d < length(Σdata[c])) for (c, d) in truncdim) - # early return - (isempty(Σdata) || !any(values(istruncated))) && return nothing - - # find some suitable starting candidate - cmax = findfirst(istruncated) - σmax = Σdata[cmax][truncdim[cmax] + 1] - - # find the actual maximal singular value - for (c, σs) in Σdata - if istruncated[c] - σ = σs[truncdim[c] + 1] - if σ > σmax - cmax, σmax = c, σ - end - end - end - return cmax -end - # Truncation # ---------- - function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ)::_T_USVᴴ, strategy::TruncationStrategy) ind = findtruncated_sorted(diagview(S), strategy) V_truncated = spacetype(S)(c => length(I) for (c, I) in ind) @@ -172,29 +49,60 @@ function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ)::_T_USVᴴ, strategy::Trun return Ũ, S̃, Ṽᴴ end -function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove) - atol = if strategy.rtol > 0 - max(strategy.atol, _norm(S, strategy.p) * strategy.rtol) - else - strategy.atol +function truncate!(::typeof(left_null!), + (U, S)::Tuple{<:AbstractTensorMap, + <:AbstractTensorMap}, + strategy::MatrixAlgebraKit.TruncationStrategy) + extended_S = SectorDict(c => vcat(diagview(b), + zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) + for (c, b) in blocks(S)) + ind = findtruncated(extended_S, strategy) + V_truncated = spacetype(S)(c => length(axes(b, 1)[ind[c]]) for (c, b) in blocks(S)) + Ũ = similar(U, codomain(U) ← V_truncated) + for (c, b) in blocks(Ũ) + copy!(b, @view(block(U, c)[:, ind[c]])) end + return Ũ +end + +# Find truncation +# --------------- +# auxiliary functions +rtol_to_atol(S, p, atol, rtol) = rtol > 0 ? max(atol, _norm(S, p) * rtol) : atol + +function _compute_truncerr(Σdata, truncdim, p=2) + I = keytype(Σdata) + S = scalartype(valtype(Σdata)) + return TensorKit._norm((c => @view(v[(get(truncdim, c, 0) + 1):end]) + for (c, v) in Σdata), + p, zero(S)) +end + +function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector} + # early return + (isempty(S) || all(iszero, values(truncdim))) && return nothing + σmin, imin = findmin(keys(truncdim)) do c + d = truncdim[c] + return S[c][d] + end + return σmin, keys(truncdim)[imin] +end + +# sorted implementations +function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove) + atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol)) return SectorDict(c => findtrunc(d) for (c, d) in Sd) end function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepBelow) - atol = if strategy.rtol > 0 - max(strategy.atol, _norm(S, strategy.p) * strategy.rtol) - else - strategy.atol - end + atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) findtrunc = Base.Fix2(findtruncated_sorted, truncabove(atol)) return SectorDict(c => findtrunc(d) for (c, d) in Sd) end function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError) I = keytype(Sd) - S = real(scalartype(valtype(Sd))) truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd) while true next = _findnexttruncvalue(Sd, truncdim) @@ -216,7 +124,6 @@ end function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted) @assert strategy.by === abs && strategy.rev == true "Not implemented" I = keytype(Sd) - S = real(scalartype(valtype(Sd))) truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd) totaldim = sum(dim(c) * d for (c, d) in truncdim; init=0) while true @@ -252,19 +159,3 @@ function findtruncated_sorted(Sd::SectorDict, strategy::TruncationIntersection) return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...) for c in intersect(map(keys, inds)...)) end - -function MatrixAlgebraKit.truncate!(::typeof(left_null!), - (U, S)::Tuple{<:AbstractTensorMap, - <:AbstractTensorMap}, - strategy::MatrixAlgebraKit.TruncationStrategy) - extended_S = SectorDict(c => vcat(MatrixAlgebraKit.diagview(b), - zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) - for (c, b) in blocks(S)) - ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) - V_truncated = spacetype(S)(c => length(axes(b, 1)[ind[c]]) for (c, b) in blocks(S)) - Ũ = similar(U, codomain(U) ← V_truncated) - for (c, b) in blocks(Ũ) - copy!(b, @view(block(U, c)[:, ind[c]])) - end - return Ũ -end From 2e16b29320021e8ff89ea78e6e54a1c51f44ec36 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 11:46:14 -0400 Subject: [PATCH 032/126] Update tuple formatting --- ext/TensorKitChainRulesCoreExt/factorizations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index a91c1dbc0..be72a1cec 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -25,8 +25,8 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ) Δt = similar(t) foreachblock(Δt) do (c, b) - USVᴴc = block(U, c), block(Σ, c), block(V⁺, c) - ΔUSVᴴc = block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c) + USVᴴc = (block(U, c), block(Σ, c), block(V⁺, c)) + ΔUSVᴴc = (block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c)) svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc) return nothing end From 86eae7e757260775bf4cce2b054ca85562d04dd5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 11:46:48 -0400 Subject: [PATCH 033/126] Fix scheduler selection --- src/tensors/backends.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/backends.jl b/src/tensors/backends.jl index 1083115b9..0fc8f99f6 100644 --- a/src/tensors/backends.jl +++ b/src/tensors/backends.jl @@ -2,7 +2,7 @@ # ------------------------ function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...) return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs) - Threads.nthreads() > 1 ? SerialScheduler() : DynamicScheduler() + Threads.nthreads() == 1 ? SerialScheduler() : DynamicScheduler() else OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...) end From 62c92934009094e270136ca9212d187db5ec2e3b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 11:48:38 -0400 Subject: [PATCH 034/126] Retain `dual` in `ominus` --- src/spaces/gradedspace.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spaces/gradedspace.jl b/src/spaces/gradedspace.jl index ddc08046d..568b13475 100644 --- a/src/spaces/gradedspace.jl +++ b/src/spaces/gradedspace.jl @@ -150,9 +150,10 @@ function ⊕(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector} end function ⊖(V::GradedSpace{I}, W::GradedSpace{I}) where {I<:Sector} - V ≿ W && isdual(V) == isdual(W) || + dual = isdual(V) + V ≿ W && dual == isdual(W) || throw(SpaceMismatch("$(W) is not a subspace of $(V)")) - return typeof(V)(c => dim(V, c) - dim(W, c) for c in sectors(V)) + return typeof(V)(c => dim(V, c) - dim(W, c) for c in sectors(V); dual) end function flip(V::GradedSpace{I}) where {I<:Sector} From 32954debadedf08f87e23efa00c3160f1ee27160 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 12:48:25 -0400 Subject: [PATCH 035/126] Update blockiterator --- src/tensors/blockiterator.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index 984facd1b..edbaa27c8 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -15,27 +15,32 @@ Base.length(iter::BlockIterator) = length(iter.structure) Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...) # TODO: fast-path when structures are the same? -# TODO: do we want f(c, bs...) or f(c, bs)? # TODO: implement scheduler -# TODO: do we prefer `blocks(t, ts...)` instead or as well? """ - foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; [scheduler]) + foreachblock(f, ts::AbstractTensorMap...; [scheduler]) Apply `f` to each block of `t` and the corresponding blocks of `ts`. Optionally, `scheduler` can be used to parallelize the computation. This function is equivalent to the following loop: ```julia -for (c, b) in blocks(t) - bs = (b, block.(ts, c)...) +for c in union(blocksectors.(ts)...) + bs = map(t -> block(t, c), ts) f(c, bs) end ``` """ function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler=nothing) - allsectors = union(blocksectors(t), blocksectors.(ts)...) + tensors = (t, ts...) + allsectors = union(blocksectors.(tensors)...) foreach(allsectors) do c - return f(c, map(Base.Fix2(block, c), (t, ts...))) + return f(c, block.(tensors, Ref(c))) + end + return nothing +end +function foreachblock(f, t::AbstractTensorMap; scheduler=nothing) + foreach(blocks(t)) do (c, b) + return f(c, (b,)) end return nothing end From c913ca44474bb82417a4eabaf5c2228960e9d601 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 16:58:44 -0400 Subject: [PATCH 036/126] Update svd rrule --- .../factorizations.jl | 169 ++---------------- 1 file changed, 15 insertions(+), 154 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index be72a1cec..e51303085 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -3,40 +3,32 @@ using MatrixAlgebraKit: svd_compact_pullback! # Factorizations rules # -------------------- function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; - trunc::TensorKit.TruncationScheme=TensorKit.notrunc(), - alg::Union{TensorKit.SVD,TensorKit.SDD}=TensorKit.SDD()) - U, Σ, V⁺, truncerr = tsvd(t; trunc=TensorKit.notrunc(), alg) - - if !(trunc == TensorKit.notrunc()) && !isempty(blocksectors(t)) - Σdata = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(Σ)) - - truncdim = TensorKit._compute_truncdim(Σdata, trunc; p=2) - truncerr = TensorKit._compute_truncerr(Σdata, truncdim; p=2) - - SVDdata = TensorKit.SectorDict(c => (block(U, c), Σc, block(V⁺, c)) - for (c, Σc) in Σdata) - - Ũ, Σ̃, Ṽ⁺ = TensorKit._create_svdtensors(t, SVDdata, truncdim) + trunc::TruncationStrategy=TensorKit.notrunc(), + kwargs...) + # TODO: I think we can use tsvd! here without issues because we don't actually require + # the data of `t` anymore. + USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), alg) + + if trunc != TensorKit.notrunc() && !isempty(blocksectors(t)) + USVᴴ′ = MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, trunc) else - Ũ, Σ̃, Ṽ⁺ = U, Σ, V⁺ + USVᴴ′ = USVᴴ end - function tsvd!_pullback(ΔUSVϵ) - ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ) + function tsvd!_pullback(ΔUSVᴴ′) + ΔUSVᴴ = unthunk.(ΔUSVᴴ′) Δt = similar(t) foreachblock(Δt) do (c, b) - USVᴴc = (block(U, c), block(Σ, c), block(V⁺, c)) - ΔUSVᴴc = (block(ΔU, c), block(ΔΣ, c), block(ΔV⁺, c)) + USVᴴc = block.(USVᴴ, Ref(c)) + ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c)) svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc) return nothing end return NoTangent(), Δt end - function tsvd!_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent() - end + tsvd!_pullback(::NTuple{3,ZeroTangent}) = NoTangent(), ZeroTangent() - return (Ũ, Σ̃, Ṽ⁺, truncerr), tsvd!_pullback + return USVᴴ′, tsvd!_pullback end function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap) @@ -173,137 +165,6 @@ function uppertriangularind(A::AbstractMatrix) return I end -# SVD_pullback: pullback implementation for general (possibly truncated) SVD -# -# Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as -# cotangent ΔU, ΔS, ΔVd variables of truncated SVD -# -# Checks whether the cotangent variables are such that they would couple to gauge-dependent -# degrees of freedom (phases of singular vectors), and prints a warning if this is the case -# -# An implementation that only uses U, S, and Vd from truncated SVD is also possible, but -# requires solving a Sylvester equation, which does not seem to be supported on GPUs. -# -# Other implementation considerations for GPU compatibility: -# no scalar indexing, lots of broadcasting and views -# -# function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector, -# Vd::AbstractMatrix, ΔU, ΔS, ΔVd; -# tol::Real=default_pullback_gaugetol(S)) - -# # Basic size checks and determination -# m, n = size(U, 1), size(Vd, 2) -# size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch()) -# p = -1 -# if !(ΔU isa AbstractZero) -# m == size(ΔU, 1) || throw(DimensionMismatch()) -# p = size(ΔU, 2) -# end -# if !(ΔVd isa AbstractZero) -# n == size(ΔVd, 2) || throw(DimensionMismatch()) -# if p == -1 -# p = size(ΔVd, 1) -# else -# p == size(ΔVd, 1) || throw(DimensionMismatch()) -# end -# end -# if !(ΔS isa AbstractZero) -# if p == -1 -# p = length(ΔS) -# else -# p == length(ΔS) || throw(DimensionMismatch()) -# end -# end -# Up = view(U, :, 1:p) -# Vp = view(Vd, 1:p, :)' -# Sp = view(S, 1:p) - -# # rank -# r = searchsortedlast(S, tol; rev=true) - -# # compute antihermitian part of projection of ΔU and ΔV onto U and V -# # also already subtract this projection from ΔU and ΔV -# if !(ΔU isa AbstractZero) -# UΔU = Up' * ΔU -# aUΔU = rmul!(UΔU - UΔU', 1 / 2) -# if m > p -# ΔU -= Up * UΔU -# end -# else -# aUΔU = fill!(similar(U, (p, p)), 0) -# end -# if !(ΔVd isa AbstractZero) -# VΔV = Vp' * ΔVd' -# aVΔV = rmul!(VΔV - VΔV', 1 / 2) -# if n > p -# ΔVd -= VΔV' * Vp' -# end -# else -# aVΔV = fill!(similar(Vd, (p, p)), 0) -# end - -# # check whether cotangents arise from gauge-invariance objective function -# mask = abs.(Sp' .- Sp) .< tol -# Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) -# if p > r -# rprange = (r + 1):p -# Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf)) -# Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf)) -# end -# Δgauge < tol || -# @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - -# UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+ -# (aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol) -# if !(ΔS isa ZeroTangent) -# UdΔAV[diagind(UdΔAV)] .+= real.(ΔS) -# # in principle, ΔS is real, but maybe not if coming from an anyonic tensor -# end -# mul!(ΔA, Up, UdΔAV * Vp') - -# if r > p # contribution from truncation -# Ur = view(U, :, (p + 1):r) -# Vr = view(Vd, (p + 1):r, :)' -# Sr = view(S, (p + 1):r) - -# if !(ΔU isa AbstractZero) -# UrΔU = Ur' * ΔU -# if m > r -# ΔU -= Ur * UrΔU # subtract this part from ΔU -# end -# else -# UrΔU = fill!(similar(U, (r - p, p)), 0) -# end -# if !(ΔVd isa AbstractZero) -# VrΔV = Vr' * ΔVd' -# if n > r -# ΔVd -= VrΔV' * Vr' # subtract this part from ΔV -# end -# else -# VrΔV = fill!(similar(Vd, (r - p, p)), 0) -# end - -# X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+ -# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol)) -# Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .- -# (UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol)) - -# # ΔA += Ur * X * Vp' + Up * Y' * Vr' -# mul!(ΔA, Ur, X * Vp', 1, 1) -# mul!(ΔA, Up * Y', Vr', 1, 1) -# end - -# if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)] -# # ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp' -# mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1) -# end -# if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)] -# # ΔA += U * (safe_inv.(Sp, tol) .* ΔVd) -# mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1) -# end -# return ΔA -# end - function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; tol::Real=default_pullback_gaugetol(D)) From c4473c26d46904fd74edccfc7eaf02864ebe4ddc Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 17:04:26 -0400 Subject: [PATCH 037/126] Update eig(h) rrule --- .../factorizations.jl | 117 ++++-------------- 1 file changed, 21 insertions(+), 96 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index e51303085..7ba026c95 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,4 +1,4 @@ -using MatrixAlgebraKit: svd_compact_pullback! +using MatrixAlgebraKit: svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback! # Factorizations rules # -------------------- @@ -46,47 +46,41 @@ function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTenso end function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...) - D, V = eig(t; kwargs...) + DV = eig(t; kwargs...) - function eig!_pullback((_ΔD, _ΔV)) - ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV) + function eig!_pullback(ΔDV′) + ΔDV = unthunk.(ΔDV′) Δt = similar(t) - for (c, b) in blocks(Δt) - Dc, Vc = block(D, c), block(V, c) - ΔDc, ΔVc = block(ΔD, c), block(ΔV, c) - Ddc = view(Dc, diagind(Dc)) - ΔDdc = (ΔDc isa AbstractZero) ? ΔDc : view(ΔDc, diagind(ΔDc)) - eig_pullback!(b, Ddc, Vc, ΔDdc, ΔVc) + foreachblock(Δt) do (c, b) + DVc = block.(DV, Ref(c)) + ΔDVc = block.(ΔDV, Ref(c)) + eig_full_pullback!(b, DVc, ΔDVc) + return nothing end return NoTangent(), Δt end - function eig!_pullback(::Tuple{ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent() - end + eig!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() - return (D, V), eig!_pullback + return DV, eig!_pullback end function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...) - D, V = eigh(t; kwargs...) + DV = eigh(t; kwargs...) - function eigh!_pullback((_ΔD, _ΔV)) - ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV) + function eigh!_pullback(ΔDV′) + ΔDV = unthunk.(ΔDV′) Δt = similar(t) - for (c, b) in blocks(Δt) - Dc, Vc = block(D, c), block(V, c) - ΔDc, ΔVc = block(ΔD, c), block(ΔV, c) - Ddc = view(Dc, diagind(Dc)) - ΔDdc = (ΔDc isa AbstractZero) ? ΔDc : view(ΔDc, diagind(ΔDc)) - eigh_pullback!(b, Ddc, Vc, ΔDdc, ΔVc) + foreachblock(Δt) do (c, b) + DVc = block.(DV, Ref(c)) + ΔDVc = block.(ΔDV, Ref(c)) + eigh_full_pullback!(b, DVc, ΔDVc) + return nothing end return NoTangent(), Δt end - function eigh!_pullback(::Tuple{ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent() - end + eigh!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() - return (D, V), eigh!_pullback + return DV, eigh!_pullback end function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap; @@ -165,75 +159,6 @@ function uppertriangularind(A::AbstractMatrix) return I end -function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; - tol::Real=default_pullback_gaugetol(D)) - - # Basic size checks and determination - n = LinearAlgebra.checksquare(V) - n == length(D) || throw(DimensionMismatch()) - - if !(ΔV isa AbstractZero) - VdΔV = V' * ΔV - - mask = abs.(transpose(D) .- D) .< tol - Δgauge = norm(view(VdΔV, mask), Inf) - Δgauge < tol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - - VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol)) - - if !(ΔD isa AbstractZero) - view(VdΔV, diagind(VdΔV)) .+= ΔD - end - PΔV = V' \ VdΔV - if eltype(ΔA) <: Real - ΔAc = mul!(VdΔV, PΔV, V') # recycle VdΔV memory - ΔA .= real.(ΔAc) - else - mul!(ΔA, PΔV, V') - end - else - PΔV = V' \ Diagonal(ΔD) - if eltype(ΔA) <: Real - ΔAc = PΔV * V' - ΔA .= real.(ΔAc) - else - mul!(ΔA, PΔV, V') - end - end - return ΔA -end - -function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV; - tol::Real=default_pullback_gaugetol(D)) - - # Basic size checks and determination - n = LinearAlgebra.checksquare(V) - n == length(D) || throw(DimensionMismatch()) - - if !(ΔV isa AbstractZero) - VdΔV = V' * ΔV - aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2) - - mask = abs.(D' .- D) .< tol - Δgauge = norm(view(aVdΔV, mask)) - Δgauge < tol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - - aVdΔV .*= safe_inv.(D' .- D, tol) - - if !(ΔD isa AbstractZero) - view(aVdΔV, diagind(aVdΔV)) .+= real.(ΔD) - # in principle, ΔD is real, but maybe not if coming from an anyonic tensor - end - # recylce VdΔV space - mul!(ΔA, mul!(VdΔV, V, aVdΔV), V') - else - mul!(ΔA, V * Diagonal(ΔD), V') - end - return ΔA -end - function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR; tol::Real=default_pullback_gaugetol(R)) Rd = view(R, diagind(R)) From d5031b246e6375ef0a289756f9291d8f841b15aa Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 17:41:48 -0400 Subject: [PATCH 038/126] Implement `isposdef` --- src/tensors/factorizations/factorizations.jl | 19 ------------------- src/tensors/factorizations/interface.jl | 14 +++++++++++++- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index a83b77a35..104cd7a4e 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -45,25 +45,6 @@ include("matrixalgebrakit.jl") include("truncation.jl") include("deprecations.jl") -""" - isposdef(t::AbstractTensor, (leftind, rightind)::Index2Tuple) -> ::Bool - -Test whether a tensor `t` is positive definite as linear map from `rightind` to `leftind`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `isposdef!(t)`. Note that the permuted tensor on -which `isposdef!` is called should have equal domain and codomain, as otherwise it is -meaningless. -""" -function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) - tcopy = permutedcopy_oftype(t, factorisation_scalartype(isposdef, t), p) - return isposdef!(tcopy) -end -function LinearAlgebra.isposdef(t::AbstractTensorMap) - tcopy = copy_oftype(t, float(scalartype(t))) - return isposdef!(tcopy) -end function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) t = permute(t, (p₁, p₂); copy=false) diff --git a/src/tensors/factorizations/interface.jl b/src/tensors/factorizations/interface.jl index 821bc15c3..56d462887 100644 --- a/src/tensors/factorizations/interface.jl +++ b/src/tensors/factorizations/interface.jl @@ -217,9 +217,21 @@ matrices. See the corresponding documentation for more information. See also [`eig(!)`](@ref eig) and [`eigh(!)`](@ref) """ eigen(::AbstractTensorMap), eigen!(::AbstractTensorMap) +@doc """ + isposdef(t::AbstractTensor, [(leftind, rightind)::Index2Tuple]) -> ::Bool + +Test whether a tensor `t` is positive definite as linear map from `rightind` to `leftind`. + +If `leftind` and `rightind` are not specified, the current partition of left and right +indices of `t` is used. In that case, less memory is allocated if one allows the data in +`t` to be destroyed/overwritten, by using `isposdef!(t)`. Note that the permuted tensor on +which `isposdef!` is called should have equal domain and codomain, as otherwise it is +meaningless. +""" isposdef(::AbstractTensorMap), isposdef!(::AbstractTensorMap) + for f in (:tsvd, :eig, :eigh, :eigen, :leftorth, :rightorth, :leftpolar, :rightpolar, :leftnull, - :rightnull) + :rightnull, :isposdef) f! = Symbol(f, :!) @eval function $f(t::AbstractTensorMap, p::Index2Tuple; kwargs...) tcopy = permutedcopy_oftype(t, factorisation_scalartype($f, t), p) From a7ae652360d21b26e23b0af3eb7f828b1a9cda49 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 17:55:30 -0400 Subject: [PATCH 039/126] Fix imports --- ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl | 5 +++++ ext/TensorKitChainRulesCoreExt/factorizations.jl | 2 -- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl index 16c7583d1..98be06676 100644 --- a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -3,6 +3,7 @@ module TensorKitChainRulesCoreExt using TensorOperations using VectorInterface using TensorKit +using TensorKit: foreachblock using ChainRulesCore using LinearAlgebra using TupleTools @@ -11,6 +12,10 @@ import TensorOperations as TO using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract using VectorInterface: promote_scale, promote_add +using MatrixAlgebraKit +using MatrixAlgebraKit: TruncationStrategy, + svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback! + include("utility.jl") include("constructors.jl") include("linalg.jl") diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index 7ba026c95..fe0f31515 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,5 +1,3 @@ -using MatrixAlgebraKit: svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback! - # Factorizations rules # -------------------- function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; From 58cb5235057ca07d9b4c729add631f316c32d54d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 17:55:39 -0400 Subject: [PATCH 040/126] Update tests and fixes Small fixes --- .../factorizations.jl | 8 +++---- src/tensors/factorizations/factorizations.jl | 1 - test/ad.jl | 24 +++++++++---------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index fe0f31515..b80f0b2cd 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -5,7 +5,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; kwargs...) # TODO: I think we can use tsvd! here without issues because we don't actually require # the data of `t` anymore. - USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), alg) + USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), kwargs...) if trunc != TensorKit.notrunc() && !isempty(blocksectors(t)) USVᴴ′ = MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, trunc) @@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; function tsvd!_pullback(ΔUSVᴴ′) ΔUSVᴴ = unthunk.(ΔUSVᴴ′) Δt = similar(t) - foreachblock(Δt) do (c, b) + foreachblock(Δt) do c, (b,) USVᴴc = block.(USVᴴ, Ref(c)) ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c)) svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc) @@ -49,7 +49,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kw function eig!_pullback(ΔDV′) ΔDV = unthunk.(ΔDV′) Δt = similar(t) - foreachblock(Δt) do (c, b) + foreachblock(Δt) do c, (b,) DVc = block.(DV, Ref(c)) ΔDVc = block.(ΔDV, Ref(c)) eig_full_pullback!(b, DVc, ΔDVc) @@ -68,7 +68,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k function eigh!_pullback(ΔDV′) ΔDV = unthunk.(ΔDV′) Δt = similar(t) - foreachblock(Δt) do (c, b) + foreachblock(Δt) do c, (b,) DVc = block.(DV, Ref(c)) ΔDVc = block.(ΔDV, Ref(c)) eigh_full_pullback!(b, DVc, ΔDVc) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 104cd7a4e..2684d3133 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -45,7 +45,6 @@ include("matrixalgebrakit.jl") include("truncation.jl") include("deprecations.jl") - function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) t = permute(t, (p₁, p₂); copy=false) return isisometry(t) diff --git a/test/ad.jl b/test/ad.jl index e5e2d884d..9f5eb2a5b 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -398,7 +398,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(eigh′, H; atol, output_tangent=(ΔD, ΔU)) end - let (U, S, V, ϵ) = tsvd(A) + let (U, S, V) = tsvd(A) ΔU = randn(scalartype(U), space(U)) ΔS = randn(scalartype(S), space(S)) ΔV = randn(scalartype(V), space(V)) @@ -408,54 +408,54 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1) end end - test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0)) + test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV)) allS = mapreduce(x -> diag(x[2]), vcat, blocks(S)) truncval = (maximum(allS) + minimum(allS)) / 2 - U, S, V, ϵ = tsvd(A; trunc=truncerr(truncval)) + U, S, V = tsvd(A; trunc=truncerr(truncval)) ΔU = randn(scalartype(U), space(U)) ΔS = randn(scalartype(S), space(S)) ΔV = randn(scalartype(V), space(V)) T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0), + test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV), fkwargs=(; trunc=truncerr(truncval))) end - let (U, S, V, ϵ) = tsvd(B) + let (U, S, V) = tsvd(B) ΔU = randn(scalartype(U), space(U)) ΔS = randn(scalartype(S), space(S)) ΔV = randn(scalartype(V), space(V)) T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0)) + test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV)) Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2) for (c, b) in blocks(S))) - U, S, V, ϵ = tsvd(B; trunc=truncspace(Vtrunc)) + U, S, V = tsvd(B; trunc=truncspace(Vtrunc)) ΔU = randn(scalartype(U), space(U)) ΔS = randn(scalartype(S), space(S)) ΔV = randn(scalartype(V), space(V)) T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0), + test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV), fkwargs=(; trunc=truncspace(Vtrunc))) end - let (U, S, V, ϵ) = tsvd(C) + let (U, S, V) = tsvd(C) ΔU = randn(scalartype(U), space(U)) ΔS = randn(scalartype(S), space(S)) ΔV = randn(scalartype(V), space(V)) T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0)) + test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV)) c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) trunc = truncdim(round(Int, 2 * dim(c))) - U, S, V, ϵ = tsvd(C; trunc) + U, S, V = tsvd(C; trunc) ΔU = randn(scalartype(U), space(U)) ΔS = randn(scalartype(S), space(S)) ΔV = randn(scalartype(V), space(V)) T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0), fkwargs=(; trunc)) + test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV), fkwargs=(; trunc)) end let D = LinearAlgebra.eigvals(C) From 485239a6801efdfa9a684b2df8dc57ff1e506d3f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 12 Jun 2025 22:59:27 -0400 Subject: [PATCH 041/126] Clean up tests --- test/factorizations.jl | 340 ----------------------------------------- test/paul.jl | 65 -------- 2 files changed, 405 deletions(-) delete mode 100644 test/factorizations.jl delete mode 100644 test/paul.jl diff --git a/test/factorizations.jl b/test/factorizations.jl deleted file mode 100644 index 17642aa93..000000000 --- a/test/factorizations.jl +++ /dev/null @@ -1,340 +0,0 @@ -using TestEnv; -TestEnv.activate(); - -@testsnippet Setup begin - using Test - using TestExtras - using Random - using TensorKit - using Combinatorics - using TensorKit: ProductSector, fusiontensor, pentagon_equation, hexagon_equation - using TensorOperations - using Base.Iterators: take, product - # using SUNRepresentations: SUNIrrep - # const SU3Irrep = SUNIrrep{3} - using LinearAlgebra: LinearAlgebra - using Zygote: Zygote - using MatrixAlgebraKit - - const TK = TensorKit - - Random.seed!(1234) - - smallset(::Type{I}) where {I<:Sector} = take(values(I), 5) - function smallset(::Type{ProductSector{Tuple{I1,I2}}}) where {I1,I2} - iter = product(smallset(I1), smallset(I2)) - s = collect(i ⊠ j for (i, j) in iter if dim(i) * dim(j) <= 6) - return length(s) > 6 ? rand(s, 6) : s - end - function smallset(::Type{ProductSector{Tuple{I1,I2,I3}}}) where {I1,I2,I3} - iter = product(smallset(I1), smallset(I2), smallset(I3)) - s = collect(i ⊠ j ⊠ k for (i, j, k) in iter if dim(i) * dim(j) * dim(k) <= 6) - return length(s) > 6 ? rand(s, 6) : s - end - function randsector(::Type{I}) where {I<:Sector} - s = collect(smallset(I)) - a = rand(s) - while a == one(a) # don't use trivial label - a = rand(s) - end - return a - end - function hasfusiontensor(I::Type{<:Sector}) - try - fusiontensor(one(I), one(I), one(I)) - return true - catch e - if e isa MethodError - return false - else - rethrow(e) - end - end - end - - # spaces - Vtr = (ℂ^3, - (ℂ^4)', - ℂ^5, - ℂ^6, - (ℂ^7)') - Vℤ₂ = (ℂ[Z2Irrep](0 => 1, 1 => 1), - ℂ[Z2Irrep](0 => 1, 1 => 2)', - ℂ[Z2Irrep](0 => 3, 1 => 2)', - ℂ[Z2Irrep](0 => 2, 1 => 3), - ℂ[Z2Irrep](0 => 2, 1 => 5)) - Vfℤ₂ = (ℂ[FermionParity](0 => 1, 1 => 1), - ℂ[FermionParity](0 => 1, 1 => 2)', - ℂ[FermionParity](0 => 3, 1 => 2)', - ℂ[FermionParity](0 => 2, 1 => 3), - ℂ[FermionParity](0 => 2, 1 => 5)) - Vℤ₃ = (ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 2), - ℂ[Z3Irrep](0 => 3, 1 => 1, 2 => 1), - ℂ[Z3Irrep](0 => 2, 1 => 2, 2 => 1)', - ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 3), - ℂ[Z3Irrep](0 => 1, 1 => 3, 2 => 3)') - VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), - ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), - ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', - ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 3), - ℂ[U1Irrep](0 => 1, 1 => 3, -1 => 3)') - VfU₁ = (ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 2), - ℂ[FermionNumber](0 => 3, 1 => 1, -1 => 1), - ℂ[FermionNumber](0 => 2, 1 => 2, -1 => 1)', - ℂ[FermionNumber](0 => 1, 1 => 2, -1 => 3), - ℂ[FermionNumber](0 => 1, 1 => 3, -1 => 3)') - VCU₁ = (ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 2, 1 => 1), - ℂ[CU1Irrep]((0, 0) => 3, (0, 1) => 0, 1 => 1), - ℂ[CU1Irrep]((0, 0) => 1, (0, 1) => 0, 1 => 2)', - ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 2, 1 => 1), - ℂ[CU1Irrep]((0, 0) => 2, (0, 1) => 1, 1 => 2)') - VSU₂ = (ℂ[SU2Irrep](0 => 3, 1 // 2 => 1), - ℂ[SU2Irrep](0 => 2, 1 => 1), - ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)', - ℂ[SU2Irrep](0 => 2, 1 // 2 => 2), - ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') - VfSU₂ = (ℂ[FermionSpin](0 => 3, 1 // 2 => 1), - ℂ[FermionSpin](0 => 2, 1 => 1), - ℂ[FermionSpin](1 // 2 => 1, 1 => 1)', - ℂ[FermionSpin](0 => 2, 1 // 2 => 2), - ℂ[FermionSpin](0 => 1, 1 // 2 => 1, 3 // 2 => 1)') - for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) - V1, V2, V3, V4, V5 = V - - @assert V3 * V4 * V2 ≿ V1' * V5' # necessary for leftorth tests - @assert V3 * V4 ≾ V1' * V2' * V5' # necessary for rightorth tests - end - - spacelist = try - if ENV["CI"] == "true" - println("Detected running on CI") - if Sys.iswindows() - (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂) - elseif Sys.isapple() - (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VfU₁, VfSU₂)#, VSU₃) - else - (Vtr, Vℤ₂, Vfℤ₂, VU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) - end - else - (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) - end - catch - (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃) - end -end - -@testitem "left orth" setup = [Setup] begin - function test_leftorth(t, p, alg) - Q, R = @inferred leftorth(t, p; alg) - @test Q * R ≈ permute(t, p) - @test isisometry(Q) - end - - p = ((3, 4, 2), (1, 5)) - elts = (Float32, ComplexF64) - algs = (TensorKit.QR(), TensorKit.QRpos(), TensorKit.QL(), TensorKit.QLpos(), - TensorKit.Polar(), TensorKit.SVD(), TensorKit.SDD()) - - testname(V) = "symmetry: $(TensorKit.type_repr(sectortype(first(V))))" - @timedtestset "$(testname(V))" for V in spacelist - W = ⊗(V...) - for T in elts, alg in algs - t = rand(T, W) - test_leftorth(t, p, alg) - tᴴ = t' - test_leftorth(tᴴ, p, alg) - end - end -end - -function test_leftnull(t, p, alg) - N = @inferred leftnull(t, p; alg) - @test isisometry(N) - @test norm(N' * permute(t, p)) ≈ 0 atol = 100 * eps(norm(t)) -end - -# @timedtestset "Factorizations with symmetry: $(sectortype(first(V)))" for V in spacelist -V = collect(spacelist)[1] -V1, V2, V3, V4, V5 = V -W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 -for T in (Float32, ComplexF64), adj in (false, true) - t = adj ? rand(T, W)' : rand(T, W) - @testset "leftorth with $alg" for alg in - (TensorKit.QR(), TensorKit.QRpos(), TensorKit.QL(), - TensorKit.QLpos(), TensorKit.Polar(), - TensorKit.SVD(), TensorKit.SDD()) - test_leftorth(t, ((3, 4, 2), (1, 5)), alg) - end - @testset "leftnull with $alg" for alg in - (TensorKit.QR(), TensorKit.SVD(), TensorKit.SDD()) - test_leftnull(t, ((3, 4, 2), (1, 5)), alg) - end - @testset "rightorth with $alg" for alg in - (TensorKit.RQ(), TensorKit.RQpos(), - TensorKit.LQ(), TensorKit.LQpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - L, Q = @constinferred rightorth(t, ((3, 4), (2, 1, 5)); alg=alg) - QQd = Q * Q' - @test QQd ≈ one(QQd) - @test L * Q ≈ permute(t, ((3, 4), (2, 1, 5))) - if alg isa Polar - @test isposdef(L) - @test domain(L) == codomain(L) == space(t, 3) ⊗ space(t, 4) - end - end - @testset "rightnull with $alg" for alg in - (TensorKit.LQ(), TensorKit.SVD(), - TensorKit.SDD()) - M = @constinferred rightnull(t, ((3, 4), (2, 1, 5)); alg=alg) - MMd = M * M' - @test MMd ≈ one(MMd) - @test norm(permute(t, ((3, 4), (2, 1, 5))) * M') < - 100 * eps(norm(t)) - end - @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) - U, S, V = @constinferred tsvd(t, ((3, 4, 2), (1, 5)); alg=alg) - UdU = U' * U - @test UdU ≈ one(UdU) - VVd = V * V' - @test VVd ≈ one(VVd) - t2 = permute(t, ((3, 4, 2), (1, 5))) - @test U * S * V ≈ t2 - - s = LinearAlgebra.svdvals(t2) - s′ = LinearAlgebra.diag(S) - for (c, b) in s - @test b ≈ s′[c] - end - end - @testset "cond and rank" begin - t2 = permute(t, ((3, 4, 2), (1, 5))) - d1 = dim(codomain(t2)) - d2 = dim(domain(t2)) - @test rank(t2) == min(d1, d2) - M = leftnull(t2) - @test rank(M) == max(d1, d2) - min(d1, d2) - t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) - @test cond(t3) ≈ one(real(T)) - @test rank(t3) == dim(V1 ⊗ V2) - t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) - t4 = (t4 + t4') / 2 - vals = LinearAlgebra.eigvals(t4) - λmax = maximum(s -> maximum(abs, s), values(vals)) - λmin = minimum(s -> minimum(abs, s), values(vals)) - @test cond(t4) ≈ λmax / λmin - end -end - -@testset "empty tensor" begin - for T in (Float32, ComplexF64) - T = Float64 - t = randn(T, V1 ⊗ V2, zero(V1)) - @testset "leftorth with $alg" for alg in - (TensorKit.QR(), TensorKit.QRpos(), - TensorKit.QL(), TensorKit.QLpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - Q, R = @constinferred leftorth(t; alg=alg) - @test Q == t - @test dim(Q) == dim(R) == 0 - end - @testset "leftnull with $alg" for alg in - (TensorKit.QR(), TensorKit.SVD(), - TensorKit.SDD()) - N = @constinferred leftnull(t; alg=alg) - @test N' * N ≈ id(domain(N)) - @test N * N' ≈ id(codomain(N)) - end - @testset "rightorth with $alg" for alg in - (TensorKit.RQ(), TensorKit.RQpos(), - TensorKit.LQ(), TensorKit.LQpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - L, Q = @constinferred rightorth(copy(t'); alg=alg) - @test Q == t' - @test dim(Q) == dim(L) == 0 - end - @testset "rightnull with $alg" for alg in - (TensorKit.LQ(), TensorKit.SVD(), - TensorKit.SDD()) - M = @constinferred rightnull(copy(t'); alg=alg) - @test M * M' ≈ id(codomain(M)) - @test M' * M ≈ id(domain(M)) - end - @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) - U, S, V = @constinferred tsvd(t; alg=alg) - @test U == t - @test dim(U) == dim(S) == dim(V) - end - @testset "cond and rank" begin - @test rank(t) == 0 - W2 = zero(V1) * zero(V2) - t2 = rand(W2, W2) - @test rank(t2) == 0 - @test cond(t2) == 0.0 - end - end -end -@testset "eig and isposdef" begin - for T in (Float32, ComplexF64) - t = rand(T, V1 ⊗ V1' ⊗ V2 ⊗ V2') - D, V = eigen(t, ((1, 3), (2, 4))) - t2 = permute(t, ((1, 3), (2, 4))) - @test t2 * V ≈ V * D - - d = LinearAlgebra.eigvals(t2; sortby=nothing) - d′ = LinearAlgebra.diag(D) - for (c, b) in d - @test b ≈ d′[c] - end - - # Somehow moving these test before the previous one gives rise to errors - # with T=Float32 on x86 platforms. Is this an OpenBLAS issue? - VdV = V' * V - VdV = (VdV + VdV') / 2 - @test isposdef(VdV) - - @test !isposdef(t2) # unlikely for non-hermitian map - t2 = (t2 + t2') - D, V = eigen(t2) - VdV = V' * V - @test VdV ≈ one(VdV) - D̃, Ṽ = @constinferred eigh(t2) - @test D ≈ D̃ - @test V ≈ Ṽ - λ = minimum(minimum(real(LinearAlgebra.diag(b))) - for (c, b) in blocks(D)) - @test cond(Ṽ) ≈ one(real(T)) - @test isposdef(t2) == isposdef(λ) - @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) - @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) - end -end -@testset "Tensor truncation" begin - for T in (Float32, ComplexF64), p in (1, 2, 3, Inf), adj in (false, true) - t = adj ? rand(T, V1 ⊗ V2 ⊗ V3, V4 ⊗ V5) : rand(T, V4 ⊗ V5, V1 ⊗ V2 ⊗ V3)' - - U₀, S₀, V₀, = tsvd(t) - t = rmul!(t, 1 / norm(S₀, p)) - U, S, V, ϵ = @constinferred tsvd(t; trunc=truncerr(5e-1), p=p) - # @show p, ϵ - # @show domain(S) - # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncerr(nextfloat(ϵ)), p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S)))), - p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - # results with truncationcutoff cannot be compared because they don't take degeneracy into account, and thus truncate differently - U, S, V, ϵ = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀))), p=p) - # @show p, ϵ - # @show domain(S) - # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′, ϵ′ = tsvd(t; trunc=truncspace(space(S, 1)), p=p) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - end -end -# end diff --git a/test/paul.jl b/test/paul.jl deleted file mode 100644 index 249ed1bae..000000000 --- a/test/paul.jl +++ /dev/null @@ -1,65 +0,0 @@ -using Zygote, TensorKit - -_safe_pow(a::Real, pow::Real, tol::Real) = (pow < 0 && abs(a) < tol) ? zero(a) : a^pow - -# Element-wise multiplication of TensorMaps respecting block structure -function _elementwise_mult(a₁::AbstractTensorMap, a₂::AbstractTensorMap) - dst = similar(a₁) - for (k, b) in blocks(dst) - copyto!(b, block(a₁, k) .* block(a₂, k)) - end - return dst -end -""" - sdiag_pow(s, pow::Real; tol::Real=eps(scalartype(s))^(3 / 4)) - -Compute `s^pow` for a diagonal matrix `s`. -""" -function sdiag_pow(s::DiagonalTensorMap, pow::Real; tol::Real=eps(scalartype(s))^(3 / 4)) - # Relative tol w.r.t. largest singular value (use norm(∘, Inf) to make differentiable) - tol *= norm(s, Inf) - spow = DiagonalTensorMap(_safe_pow.(s.data, pow, tol), space(s, 1)) - return spow -end -function sdiag_pow(s::AbstractTensorMap{T,S,1,1}, pow::Real; - tol::Real=eps(scalartype(s))^(3 / 4)) where {T,S} - # Relative tol w.r.t. largest singular value (use norm(∘, Inf) to make differentiable) - tol *= norm(s, Inf) - spow = similar(s) - for (k, b) in blocks(s) - copyto!(block(spow, k), - LinearAlgebra.diagm(_safe_pow.(LinearAlgebra.diag(b), pow, tol))) - end - return spow -end - -function ChainRulesCore.rrule(::typeof(sdiag_pow), - s::AbstractTensorMap, - pow::Real; - tol::Real=eps(scalartype(s))^(3 / 4),) - tol *= norm(s, Inf) - spow = sdiag_pow(s, pow; tol) - spow_minus1_conj = scale!(sdiag_pow(s', pow - 1; tol), pow) - function sdiag_pow_pullback(c̄_) - c̄ = unthunk(c̄_) - return (ChainRulesCore.NoTangent(), _elementwise_mult(c̄, spow_minus1_conj)) - end - return spow, sdiag_pow_pullback -end - -function svd_fixed_point(A, U, S, V) - S⁻¹ = sdiag_pow(S, -1) - return (A * V' * S⁻¹ - U, DiagonalTensorMap(U' * A * V' * S⁻¹) - one(S), - S⁻¹ * U' * A - V) -end - -using Zygote - -V = ComplexSpace(3)^2 -A = randn(ComplexF64, V, V) -U, S, V = tsvd(A) - -Zygote.gradient(A, U, S, V) do A, U, S, V - du, ds, dv = svd_fixed_point(A, U, S, V) - return norm(du) + norm(ds) + norm(dv) -end From f31dcd54029b98553f543500af85292179f064b4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 15 Jun 2025 13:22:07 -0400 Subject: [PATCH 042/126] Bump minimal MatrixAlgebraKit version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 474c57a10..6843ce3db 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.2" +MatrixAlgebraKit = "0.2.5" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" From b99fec33e601167e9afbbe0951de722326226265 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 15 Jun 2025 13:40:43 -0400 Subject: [PATCH 043/126] Fix uninitialized cotangents --- ext/TensorKitChainRulesCoreExt/factorizations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index b80f0b2cd..de4dcbded 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -15,7 +15,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; function tsvd!_pullback(ΔUSVᴴ′) ΔUSVᴴ = unthunk.(ΔUSVᴴ′) - Δt = similar(t) + Δt = zerovector(t) foreachblock(Δt) do c, (b,) USVᴴc = block.(USVᴴ, Ref(c)) ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c)) @@ -48,7 +48,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kw function eig!_pullback(ΔDV′) ΔDV = unthunk.(ΔDV′) - Δt = similar(t) + Δt = zerovector(t) foreachblock(Δt) do c, (b,) DVc = block.(DV, Ref(c)) ΔDVc = block.(ΔDV, Ref(c)) @@ -67,7 +67,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k function eigh!_pullback(ΔDV′) ΔDV = unthunk.(ΔDV′) - Δt = similar(t) + Δt = zerovector(t) foreachblock(Δt) do c, (b,) DVc = block.(DV, Ref(c)) ΔDVc = block.(ΔDV, Ref(c)) From 5c9a79bc4e1744133ee870b56b84fabe3d89cb9b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 15 Jun 2025 13:56:57 -0400 Subject: [PATCH 044/126] Update and use `MatrixAlgebraKit.isisometry` Correctly implement `isisometry` --- src/TensorKit.jl | 2 +- src/auxiliary/linalg.jl | 2 - src/tensors/factorizations/factorizations.jl | 13 +++--- test/tensors.jl | 42 +++++--------------- 4 files changed, 20 insertions(+), 39 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 7eb3652ec..b087de47e 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -73,7 +73,7 @@ export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby! export leftorth, rightorth, leftnull, rightnull, leftpolar, rightpolar, leftorth!, rightorth!, leftnull!, rightnull!, leftpolar!, rightpolar!, tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, - isposdef, isposdef!, ishermitian, isisometry, sylvester, rank, cond + isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! export catdomain, catcodomain, absorb, absorb! diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index fa6e5e248..82e8600f0 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -84,8 +84,6 @@ end safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s)) safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) -isisometry(A::StridedMatrix; kwargs...) = isapprox(A' * A, LinearAlgebra.I, kwargs...) - function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real) iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) m, n = size(A) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 2684d3133..864087d88 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -184,12 +184,15 @@ function LinearAlgebra.isposdef!(t::TensorMap) end # TODO: tolerances are per-block, not global or weighted - does that matter? -function isisometry(t::AbstractTensorMap; kwargs...) +function MatrixAlgebraKit.is_left_isometry(t::AbstractTensorMap; kwargs...) domain(t) ≾ codomain(t) || return false - for (_, b) in blocks(t) - MatrixAlgebra.isisometry(b; kwargs...) || return false - end - return true + f((c, b)) = MatrixAlgebraKit.is_left_isometry(b; kwargs...) + return all(f, blocks(t)) +end +function MatrixAlgebraKit.is_right_isometry(t::AbstractTensorMap; kwargs...) + domain(t) ≿ codomain(t) || return false + f((c, b)) = MatrixAlgebraKit.is_right_isometry(b; kwargs...) + return all(f, blocks(t)) end end diff --git a/test/tensors.jl b/test/tensors.jl index 4280348fb..5d6980ddb 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -369,9 +369,8 @@ for V in spacelist for T in (Float64, ComplexF64) t1 = randisometry(T, W1, W2) t2 = randisometry(T, W2 ← W2) - @test t1' * t1 ≈ one(t2) - @test t2' * t2 ≈ one(t2) - @test t2 * t2' ≈ one(t2) + @test isisometry(t1) + @test isunitary(t2) P = t1 * t1' @test P * P ≈ P end @@ -451,21 +450,14 @@ for V in spacelist TensorKit.Polar(), TensorKit.SVD(), TensorKit.SDD()) Q, R = @constinferred leftorth(t, ((3, 4, 2), (1, 5)); alg=alg) - QdQ = Q' * Q - @test QdQ ≈ one(QdQ) + @test isisometry(Q) @test Q * R ≈ permute(t, ((3, 4, 2), (1, 5))) - # removed since leftorth now merges legs! - # if alg isa Polar - # @test isposdef(R) - # @test domain(R) == codomain(R) == space(t, 1)' ⊗ space(t, 5)' - # end end @testset "leftnull with $alg" for alg in (TensorKit.QR(), TensorKit.SVD(), TensorKit.SDD()) N = @constinferred leftnull(t, ((3, 4, 2), (1, 5)); alg=alg) - NdN = N' * N - @test NdN ≈ one(NdN) + @test isisometry(N) @test norm(N' * permute(t, ((3, 4, 2), (1, 5)))) < 100 * eps(norm(t)) end @@ -475,30 +467,21 @@ for V in spacelist TensorKit.Polar(), TensorKit.SVD(), TensorKit.SDD()) L, Q = @constinferred rightorth(t, ((3, 4), (2, 1, 5)); alg=alg) - QQd = Q * Q' - @test QQd ≈ one(QQd) + @test isisometry(Q; side=:right) @test L * Q ≈ permute(t, ((3, 4), (2, 1, 5))) - # removed since rightorth now merges legs! - # if alg isa Polar - # @test isposdef(L) - # @test domain(L) == codomain(L) == space(t, 3) ⊗ space(t, 4) - # end end @testset "rightnull with $alg" for alg in (TensorKit.LQ(), TensorKit.SVD(), TensorKit.SDD()) M = @constinferred rightnull(t, ((3, 4), (2, 1, 5)); alg=alg) - MMd = M * M' - @test MMd ≈ one(MMd) + @test isisometry(M; side=:right) @test norm(permute(t, ((3, 4), (2, 1, 5))) * M') < 100 * eps(norm(t)) end @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) U, S, V = @constinferred tsvd(t, ((3, 4, 2), (1, 5)); alg=alg) - UdU = U' * U - @test UdU ≈ one(UdU) - VVd = V * V' - @test VVd ≈ one(VVd) + @test isisometry(U) + @test isisometry(V; side=:right) t2 = permute(t, ((3, 4, 2), (1, 5))) @test U * S * V ≈ t2 @@ -541,8 +524,7 @@ for V in spacelist (TensorKit.QR(), TensorKit.SVD(), TensorKit.SDD()) N = @constinferred leftnull(t; alg=alg) - @test N' * N ≈ id(domain(N)) - @test N * N' ≈ id(codomain(N)) + @test isunitary(N) end @testset "rightorth with $alg" for alg in (TensorKit.RQ(), TensorKit.RQpos(), @@ -557,8 +539,7 @@ for V in spacelist (TensorKit.LQ(), TensorKit.SVD(), TensorKit.SDD()) M = @constinferred rightnull(copy(t'); alg=alg) - @test M * M' ≈ id(codomain(M)) - @test M' * M ≈ id(domain(M)) + @test isunitary(M) end @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) U, S, V = @constinferred tsvd(t; alg=alg) @@ -594,8 +575,7 @@ for V in spacelist @test !isposdef(t2) # unlikely for non-hermitian map t2 = (t2 + t2') D, V = eigen(t2) - VdV = V' * V - @test VdV ≈ one(VdV) + @test isisometry(V) D̃, Ṽ = @constinferred eigh(t2) @test D ≈ D̃ @test V ≈ Ṽ From 61a1831691ea666f40ff56ecd76b0038cbb82b58 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Jun 2025 19:58:32 -0400 Subject: [PATCH 045/126] Fix missing export --- src/tensors/factorizations/factorizations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 864087d88..95683b0d9 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -4,7 +4,7 @@ module Factorizations export eig, eig!, eigh, eigh! -export tsvd, tsvd!, svdvals +export tsvd, tsvd!, svdvals, svdvals! export leftorth, leftorth!, rightorth, rightorth! export leftnull, leftnull!, rightnull, rightnull! export leftpolar, leftpolar!, rightpolar, rightpolar! @@ -15,7 +15,7 @@ using ..TensorKit using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock using ..MatrixAlgebra: MatrixAlgebra -using LinearAlgebra: LinearAlgebra, BlasFloat +using LinearAlgebra: LinearAlgebra, BlasFloat, svdvals, svdvals! import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian using TensorOperations: Index2Tuple From 5644e7aa71d73258af8827158f77ba5030fa388d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Jun 2025 20:11:46 -0400 Subject: [PATCH 046/126] Implement remaining factorization rrules --- .../TensorKitChainRulesCoreExt.jl | 3 +- .../factorizations.jl | 175 +++--------------- 2 files changed, 25 insertions(+), 153 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl index 98be06676..36bb4108d 100644 --- a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -14,7 +14,8 @@ using VectorInterface: promote_scale, promote_add using MatrixAlgebraKit using MatrixAlgebraKit: TruncationStrategy, - svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback! + svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!, + qr_compact_pullback!, lq_compact_pullback! include("utility.jl") include("constructors.jl") diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index de4dcbded..95a62b24f 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -99,167 +99,38 @@ end function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) alg isa TensorKit.QR || alg isa TensorKit.QRpos || error("only `alg=QR()` and `alg=QRpos()` are supported") - Q, R = leftorth(t; alg) - function leftorth!_pullback((_ΔQ, _ΔR)) - ΔQ, ΔR = unthunk(_ΔQ), unthunk(_ΔR) - Δt = similar(t) - for (c, b) in blocks(Δt) - qr_pullback!(b, block(Q, c), block(R, c), block(ΔQ, c), block(ΔR, c)) + QR = leftorth(t; alg) + function leftorth!_pullback(ΔQR′) + ΔQR = unthunk.(ΔQR′) + Δt = zerovector(t) + foreachblock(Δt) do c, (b,) + QRc = block.(QR, Ref(c)) + ΔQRc = block.(ΔQR, Ref(c)) + qr_compact_pullback!(b, QRc, ΔQRc) + return nothing end return NoTangent(), Δt end - leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent() - return (Q, R), leftorth!_pullback + leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() + + return QR, leftorth!_pullback end function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) alg isa TensorKit.LQ || alg isa TensorKit.LQpos || error("only `alg=LQ()` and `alg=LQpos()` are supported") - L, Q = rightorth(t; alg) - function rightorth!_pullback((_ΔL, _ΔQ)) - ΔL, ΔQ = unthunk(_ΔL), unthunk(_ΔQ) - Δt = similar(t) - for (c, b) in blocks(Δt) - lq_pullback!(b, block(L, c), block(Q, c), block(ΔL, c), block(ΔQ, c)) + LQ = rightorth(t; alg) + function rightorth!_pullback(ΔLQ′) + ΔLQ = unthunk(ΔLQ′) + Δt = zerovector(t) + foreachblock(Δt) do c, (b,) + LQc = block.(LQ, Ref(c)) + ΔLQc = block.(ΔLQ, Ref(c)) + lq_compact_pullback!(b, LQc, ΔLQc) + return nothing end return NoTangent(), Δt end - rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = NoTangent(), ZeroTangent() - return (L, Q), rightorth!_pullback -end - -# Corresponding matrix factorisations: implemented as mutating methods -# --------------------------------------------------------------------- -# helper routines -safe_inv(a, tol) = abs(a) < tol ? zero(a) : inv(a) - -function lowertriangularind(A::AbstractMatrix) - m, n = size(A) - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) - offset = 0 - for j in 1:n - r = (j + 1):m - I[offset .- j .+ r] = (j - 1) * m .+ r - offset += length(r) - end - return I -end - -function uppertriangularind(A::AbstractMatrix) - m, n = size(A) - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) - offset = 0 - for i in 1:m - r = (i + 1):n - I[offset .- i .+ r] = i .+ m .* (r .- 1) - offset += length(r) - end - return I -end - -function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR; - tol::Real=default_pullback_gaugetol(R)) - Rd = view(R, diagind(R)) - p = something(findlast(≥(tol) ∘ abs, Rd), 0) - m, n = size(R) - - Q1 = view(Q, :, 1:p) - R1 = view(R, 1:p, :) - R11 = view(R, 1:p, 1:p) - - ΔA1 = view(ΔA, :, 1:p) - ΔQ1 = view(ΔQ, :, 1:p) - ΔR1 = view(ΔR, 1:p, :) - - M = similar(R, (p, p)) - ΔR isa AbstractZero || mul!(M, ΔR1, R1') - ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero)) - view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) - if eltype(M) <: Complex - Md = view(M, diagind(M)) - Md .= real.(Md) - end - - ΔA1 .= ΔQ1 - mul!(ΔA1, Q1, M, +1, 1) - - if n > p - R12 = view(R, 1:p, (p + 1):n) - ΔA2 = view(ΔA, :, (p + 1):n) - ΔR12 = view(ΔR, 1:p, (p + 1):n) - - if ΔR isa AbstractZero - ΔA2 .= zero(eltype(ΔA)) - else - mul!(ΔA2, Q1, ΔR12) - mul!(ΔA1, ΔA2, R12', -1, 1) - end - end - if m > p && !(ΔQ isa AbstractZero) # case where R is not full rank - Q2 = view(Q, :, (p + 1):m) - ΔQ2 = view(ΔQ, :, (p + 1):m) - Q1dΔQ2 = Q1' * ΔQ2 - Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) - Δgauge < tol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - mul!(ΔA1, Q2, Q1dΔQ2', -1, 1) - end - rdiv!(ΔA1, UpperTriangular(R11)') - return ΔA -end - -function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ; - tol::Real=default_pullback_gaugetol(L)) - Ld = view(L, diagind(L)) - p = something(findlast(≥(tol) ∘ abs, Ld), 0) - m, n = size(L) - - L1 = view(L, :, 1:p) - L11 = view(L, 1:p, 1:p) - Q1 = view(Q, 1:p, :) - - ΔA1 = view(ΔA, 1:p, :) - ΔQ1 = view(ΔQ, 1:p, :) - ΔL1 = view(ΔL, :, 1:p) - - M = similar(L, (p, p)) - ΔL isa AbstractZero || mul!(M, L1', ΔL1) - ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero)) - view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) - if eltype(M) <: Complex - Md = view(M, diagind(M)) - Md .= real.(Md) - end - - ΔA1 .= ΔQ1 - mul!(ΔA1, M, Q1, +1, 1) - - if m > p - L21 = view(L, (p + 1):m, 1:p) - ΔA2 = view(ΔA, (p + 1):m, :) - ΔL21 = view(ΔL, (p + 1):m, 1:p) - - if ΔL isa AbstractZero - ΔA2 .= zero(eltype(ΔA)) - else - mul!(ΔA2, ΔL21, Q1) - mul!(ΔA1, L21', ΔA2, -1, 1) - end - end - if n > p && !(ΔQ isa AbstractZero) # case where R is not full rank - Q2 = view(Q, (p + 1):n, :) - ΔQ2 = view(ΔQ, (p + 1):n, :) - ΔQ2Q1d = ΔQ2 * Q1' - Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1)) - Δgauge < tol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1) - end - ldiv!(LowerTriangular(L11)', ΔA1) - return ΔA -end - -function default_pullback_gaugetol(a) - n = norm(a, Inf) - return eps(eltype(n))^(3 / 4) * max(n, one(n)) + rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() + return LQ, rightorth!_pullback end From 94550da0f4296b1b18d6c686f04865f0729f50e4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Jun 2025 20:45:30 -0400 Subject: [PATCH 047/126] Implement truncated eigenvalues --- .../factorizations.jl | 92 ++++++------------- src/tensors/factorizations/implementations.jl | 24 ++++- .../factorizations/matrixalgebrakit.jl | 20 ++++ src/tensors/factorizations/truncation.jl | 61 +++++++++++- 4 files changed, 131 insertions(+), 66 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index 95a62b24f..0c6924411 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,32 +1,38 @@ # Factorizations rules # -------------------- -function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; - trunc::TruncationStrategy=TensorKit.notrunc(), - kwargs...) - # TODO: I think we can use tsvd! here without issues because we don't actually require - # the data of `t` anymore. - USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), kwargs...) - - if trunc != TensorKit.notrunc() && !isempty(blocksectors(t)) - USVᴴ′ = MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, trunc) - else - USVᴴ′ = USVᴴ - end +for f in (:tsvd, :eig, :eigh) + f! = Symbol(f, :!) + f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!) + f_pullback = Symbol(f, :_pullback) + f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!) + @eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap; + trunc::TruncationStrategy=TensorKit.notrunc(), + kwargs...) + # TODO: I think we can use f! here without issues because we don't actually require + # the data of `t` anymore. + F = $f(t; trunc=TensorKit.notrunc(), kwargs...) + + if trunc != TensorKit.notrunc() && !isempty(blocksectors(t)) + F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc) + else + F′ = F + end - function tsvd!_pullback(ΔUSVᴴ′) - ΔUSVᴴ = unthunk.(ΔUSVᴴ′) - Δt = zerovector(t) - foreachblock(Δt) do c, (b,) - USVᴴc = block.(USVᴴ, Ref(c)) - ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c)) - svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc) - return nothing + function $f_pullback(ΔF′) + ΔF = unthunk.(ΔF′) + Δt = zerovector(t) + foreachblock(Δt) do c, (b,) + Fc = block.(F, Ref(c)) + ΔFc = block.(ΔF, Ref(c)) + $f_pullback!(b, Fc, ΔFc) + return nothing + end + return NoTangent(), Δt end - return NoTangent(), Δt - end - tsvd!_pullback(::NTuple{3,ZeroTangent}) = NoTangent(), ZeroTangent() + $f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent() - return USVᴴ′, tsvd!_pullback + return F′, $f_pullback + end end function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap) @@ -43,44 +49,6 @@ function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTenso return s, svdvals_pullback end -function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...) - DV = eig(t; kwargs...) - - function eig!_pullback(ΔDV′) - ΔDV = unthunk.(ΔDV′) - Δt = zerovector(t) - foreachblock(Δt) do c, (b,) - DVc = block.(DV, Ref(c)) - ΔDVc = block.(ΔDV, Ref(c)) - eig_full_pullback!(b, DVc, ΔDVc) - return nothing - end - return NoTangent(), Δt - end - eig!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() - - return DV, eig!_pullback -end - -function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...) - DV = eigh(t; kwargs...) - - function eigh!_pullback(ΔDV′) - ΔDV = unthunk.(ΔDV′) - Δt = zerovector(t) - foreachblock(Δt) do c, (b,) - DVc = block.(DV, Ref(c)) - ΔDVc = block.(ΔDV, Ref(c)) - eigh_full_pullback!(b, DVc, ΔDVc) - return nothing - end - return NoTangent(), Δt - end - eigh!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() - - return DV, eigh!_pullback -end - function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap; sortby=nothing, kwargs...) @assert sortby === nothing "only `sortby=nothing` is supported" diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index 89c0d0a00..56d816258 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -149,9 +149,27 @@ rightpolar!(t::AbstractTensorMap; kwargs...) = right_polar!(t; kwargs...) # Eigenvalue decomposition # ------------------------ -eigh!(t::AbstractTensorMap) = eigh_full!(t) -eig!(t::AbstractTensorMap) = eig_full!(t) -eigen!(t::AbstractTensorMap) = ishermitian(t) ? eigh!(t) : eig!(t) +function eigh!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!) + if trunc == notrunc() + return eigh_full!(t; kwargs...) + else + return eigh_trunc!(t; trunc, kwargs...) + end +end + +function eig!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eig!) + if trunc == notrunc() + return eig_full!(t; kwargs...) + else + return eig_trunc!(t; trunc, kwargs...) + end +end + +function eigen!(t::AbstractTensorMap; kwargs...) + return ishermitian(t) ? eigh!(t; kwargs...) : eig!(t; kwargs...) +end # Singular value decomposition # ---------------------------- diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 605428f84..6d34a1e20 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -193,6 +193,26 @@ function initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::Abstract return D, V end +function initialize_output(::typeof(eigh_trunc!), t::AbstractTensorMap, + alg::TruncatedAlgorithm) + return initialize_output(eigh_full!, t, alg.alg) +end + +function initialize_output(::typeof(eig_trunc!), t::AbstractTensorMap, + alg::TruncatedAlgorithm) + return initialize_output(eig_full!, t, alg.alg) +end + +function eigh_trunc!(t::AbstractTensorMap, DV, alg::TruncatedAlgorithm) + DV′ = eigh_full!(t, DV, alg.alg) + return truncate!(eigh_trunc!, DV′, alg.trunc) +end + +function eig_trunc!(t::AbstractTensorMap, DV, alg::TruncatedAlgorithm) + DV′ = eig_full!(t, DV, alg.alg) + return truncate!(eig_trunc!, DV′, alg.trunc) +end + # QR decomposition # ---------------- const _T_QR = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 4f7cec733..e172f17c6 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -65,6 +65,47 @@ function truncate!(::typeof(left_null!), return Ũ end +function truncate!(::typeof(eigh_trunc!), (D, V)::_T_DV, strategy::TruncationStrategy) + ind = findtruncated(diagview(D), strategy) + V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) + + D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) + for (c, b) in blocks(D̃) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b.diag, @view(block(D, c).diag[I])) + end + + Ṽ = similar(V, V_truncated ← domain(V)) + for (c, b) in blocks(Ṽ) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b, @view(block(V, c)[I, :])) + end + + return D̃, Ṽ +end +function truncate!(::typeof(eig_trunc!), (D, V)::_T_DV, strategy::TruncationStrategy) + ind = findtruncated(diagview(D), strategy) + V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) + + D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) + for (c, b) in blocks(D̃) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b.diag, @view(block(D, c).diag[I])) + end + + Ṽ = similar(V, V_truncated ← domain(V)) + for (c, b) in blocks(Ṽ) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b, @view(block(V, c)[I, :])) + end + + return D̃, Ṽ +end + # Find truncation # --------------- # auxiliary functions @@ -88,18 +129,28 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector} return σmin, keys(truncdim)[imin] end -# sorted implementations +# implementations function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove) atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol)) return SectorDict(c => findtrunc(d) for (c, d) in Sd) end +function findtruncated(S::SectorDict, strategy::TruncationKeepAbove) + atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) + findtrunc = Base.Fix2(findtruncated, truncbelow(atol)) + return SectorDict(c => findtrunc(d) for (c, d) in Sd) +end function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepBelow) atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) findtrunc = Base.Fix2(findtruncated_sorted, truncabove(atol)) return SectorDict(c => findtrunc(d) for (c, d) in Sd) end +function findtruncated(S::SectorDict, strategy::TruncationKeepBelow) + atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) + findtrunc = Base.Fix2(findtruncated, truncabove(atol)) + return SectorDict(c => findtrunc(d) for (c, d) in Sd) +end function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError) I = keytype(Sd) @@ -153,9 +204,17 @@ end function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepFiltered) return SectorDict(c => findtruncated_sorted(d, strategy) for (c, d) in Sd) end +function findtruncated(Sd::SectorDict, strategy::TruncationKeepFiltered) + return SectorDict(c => findtruncated(d, strategy) for (c, d) in Sd) +end function findtruncated_sorted(Sd::SectorDict, strategy::TruncationIntersection) inds = map(Base.Fix1(findtruncated_sorted, Sd), strategy) return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...) for c in intersect(map(keys, inds)...)) end +function findtruncated(Sd::SectorDict, strategy::TruncationIntersection) + inds = map(Base.Fix1(findtruncated, Sd), strategy) + return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...) + for c in intersect(map(keys, inds)...)) +end From 24748d22db52b7fede18dac85a8b19cc3835d20f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Jun 2025 21:22:54 -0400 Subject: [PATCH 048/126] Implement `TruncationKeepSorted` --- src/tensors/factorizations/truncation.jl | 31 ++++++++++++++++++------ 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index e172f17c6..d553c5c93 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -119,14 +119,23 @@ function _compute_truncerr(Σdata, truncdim, p=2) p, zero(S)) end -function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector} +function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}; by=identity, + rev::Bool=true) where {I<:Sector} # early return (isempty(S) || all(iszero, values(truncdim))) && return nothing - σmin, imin = findmin(keys(truncdim)) do c - d = truncdim[c] - return S[c][d] + if rev + σmin, imin = findmin(keys(truncdim)) do c + d = truncdim[c] + return by(S[c][d]) + end + return σmin, keys(truncdim)[imin] + else + σmax, imax = findmax(keys(truncdim)) do c + d = truncdim[c] + return by(S[c][d]) + end + return σmax, keys(truncdim)[imax] end - return σmin, keys(truncdim)[imin] end # implementations @@ -173,12 +182,18 @@ function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError) end function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted) - @assert strategy.by === abs && strategy.rev == true "Not implemented" + return findtruncated(Sd, strategy) +end +function findtruncated(Sd::SectorDict, strategy::TruncationKeepSorted) + permutations = SectorDict(c => (sortperm(d; strategy.by, strategy.rev)) + for (c, d) in Sd) + Sd = SectorDict(c => sort(d; strategy.by, strategy.rev) for (c, d) in Sd) + I = keytype(Sd) truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd) totaldim = sum(dim(c) * d for (c, d) in truncdim; init=0) while true - next = _findnexttruncvalue(Sd, truncdim) + next = _findnexttruncvalue(Sd, truncdim; strategy.by, strategy.rev) isnothing(next) && break _, cmin = next truncdim[cmin] -= 1 @@ -191,7 +206,7 @@ function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted) delete!(truncdim, cmin) end end - return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) + return SectorDict(c => permutations[c][Base.OneTo(d)] for (c, d) in truncdim) end function findtruncated_sorted(Sd::SectorDict, strategy::TruncationSpace) From f9a22f29ea7c0b7e2c27950ff51efd4042fb239c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 30 Jul 2025 15:38:22 -0400 Subject: [PATCH 049/126] correctly restrict type --- src/tensors/factorizations/implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index 56d816258..7edf518f2 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -5,7 +5,7 @@ _kindof(::Polar) = :polar for f! in (:svd_compact!, :svd_full!, :left_null_svd!, :right_null_svd!) @eval function select_algorithm(::typeof($f!), t::T, alg::SVD; - kwargs...) where {T} + kwargs...) where {T<:AbstractTensorMap} isempty(kwargs) || throw(ArgumentError("Additional keyword arguments are not allowed")) return LAPACK_QRIteration() From c17d7bbe62e64f7ed5b24b742330d30b5064c413 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 30 Jul 2025 15:40:59 -0400 Subject: [PATCH 050/126] canonical use of codomain as first arg --- src/tensors/factorizations/matrixalgebrakit.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 6d34a1e20..be6d8438f 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -383,7 +383,7 @@ function initialize_output(::typeof(left_polar!), t::AbstractTensorMap, ::Abstra end function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ) - domain(t) ≿ codomain(t) || + codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) # scalartype checks @@ -398,7 +398,7 @@ function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T end function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ) - domain(t) ≿ codomain(t) || + codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) # scalartype checks From 6dd26b872b183ceae6e46ff0c4cd8d435969e224 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 30 Jul 2025 15:44:06 -0400 Subject: [PATCH 051/126] import Diagonal --- src/tensors/factorizations/factorizations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 95683b0d9..8fd66b8b6 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -15,7 +15,7 @@ using ..TensorKit using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock using ..MatrixAlgebra: MatrixAlgebra -using LinearAlgebra: LinearAlgebra, BlasFloat, svdvals, svdvals! +using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals! import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian using TensorOperations: Index2Tuple From 0f02734a7d32db13ff488b1acedec8e427d4cfa3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 6 Aug 2025 15:36:53 -0400 Subject: [PATCH 052/126] rework algorithm selection using updated matrixalgebrakit --- src/tensors/factorizations/factorizations.jl | 5 +-- src/tensors/factorizations/implementations.jl | 39 ------------------- .../factorizations/matrixalgebrakit.jl | 29 ++++---------- 3 files changed, 8 insertions(+), 65 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 8fd66b8b6..df9e0bdc5 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -24,10 +24,7 @@ using MatrixAlgebraKit using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy, NoTruncation, TruncationKeepAbove, TruncationKeepBelow, TruncationIntersection, TruncationKeepFiltered -import MatrixAlgebraKit: select_algorithm, - default_qr_algorithm, default_lq_algorithm, - default_eig_algorithm, default_eigh_algorithm, - default_svd_algorithm, default_polar_algorithm, +import MatrixAlgebraKit: default_algorithm, copy_input, check_input, initialize_output, qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, svd_compact!, svd_full!, svd_trunc!, diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index 7edf518f2..a53fd1680 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -3,45 +3,6 @@ _kindof(::Union{QR,QRpos}) = :qr _kindof(::Union{LQ,LQpos}) = :lq _kindof(::Polar) = :polar -for f! in (:svd_compact!, :svd_full!, :left_null_svd!, :right_null_svd!) - @eval function select_algorithm(::typeof($f!), t::T, alg::SVD; - kwargs...) where {T<:AbstractTensorMap} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed")) - return LAPACK_QRIteration() - end - @eval function select_algorithm(::typeof($f!), t::AbstractTensorMap, alg::SVD; - kwargs...) - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed")) - return LAPACK_QRIteration() - end - @eval function select_algorithm(::typeof($f!), ::Type{T}, alg::SVD; - kwargs...) where {T<:AbstractTensorMap} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed")) - return LAPACK_QRIteration() - end - @eval function select_algorithm(::typeof($f!), t::T, alg::SDD; - kwargs...) where {T} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed")) - return LAPACK_DivideAndConquer() - end - @eval function select_algorithm(::typeof($f!), t::AbstractTensorMap, alg::SDD; - kwargs...) - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed")) - return LAPACK_DivideAndConquer() - end - @eval function select_algorithm(::typeof($f!), ::Type{T}, alg::SDD; - kwargs...) where {T<:AbstractTensorMap} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed")) - return LAPACK_DivideAndConquer() - end -end - leftorth!(t::AbstractTensorMap; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) function _leftorth!(t::AbstractTensorMap, ::Nothing; kwargs...) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index be6d8438f..ad95a31d2 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -1,27 +1,12 @@ # Algorithm selection # ------------------- -for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, :svd_full, - :svd_compact, :svd_vals, :svd_trunc) - @eval function copy_input(::typeof($f), t::AbstractTensorMap{<:BlasFloat}) - T = factorisation_scalartype($f, t) - return copy_oftype(t, T) - end - f! = Symbol(f, :!) - # TODO: can we move this to MAK? - @eval function select_algorithm(::typeof($f!), t::AbstractTensorMap, alg::Alg=nothing; - kwargs...) where {Alg} - return select_algorithm($f!, typeof(t), alg; kwargs...) - end - @eval function select_algorithm(::typeof($f!), ::Type{T}, alg::Alg=nothing; - kwargs...) where {T<:AbstractTensorMap,Alg} - return select_algorithm($f!, blocktype(T), alg; kwargs...) - end -end - -for f in (:qr, :lq, :svd, :eig, :eigh, :polar) - default_f_algorithm = Symbol(:default_, f, :_algorithm) - @eval function $default_f_algorithm(::Type{T}; kwargs...) where {T<:AbstractTensorMap} - return $default_f_algorithm(blocktype(T); kwargs...) +for f! in + [:svd_compact!, :svd_full!, :svd_trunc!, :svd_vals!, :qr_compact!, :qr_full!, :qr_null!, + :lq_compact!, :lq_full!, :lq_null!, :eig_full!, :eig_trunc!, :eig_vals!, :eigh_full!, + :eigh_trunc!, :eigh_vals!, :left_polar!, :right_polar!] + @eval function default_algorithm(::typeof($f!), ::Type{T}; + kwargs...) where {T<:AbstractTensorMap} + return default_algorithm($f!, blocktype(T); kwargs...) end end From a7a9b8107c5082f9004bef1fc5e51deb49bfdc5d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 6 Aug 2025 16:00:03 -0400 Subject: [PATCH 053/126] Implement singular- and eigenvalues --- src/tensors/factorizations/factorizations.jl | 7 +-- .../factorizations/matrixalgebrakit.jl | 45 ++++++++++++++++--- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index df9e0bdc5..40d16c447 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -27,13 +27,14 @@ using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrateg import MatrixAlgebraKit: default_algorithm, copy_input, check_input, initialize_output, qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, - svd_compact!, svd_full!, svd_trunc!, - eig_full!, eig_trunc!, eigh_full!, eigh_trunc!, + svd_compact!, svd_full!, svd_trunc!, svd_vals!, + eigh_full!, eigh_trunc!, eigh_vals!, + eig_full!, eig_trunc!, eig_vals!, left_polar!, left_orth_polar!, right_polar!, right_orth_polar!, left_null_svd!, right_null_svd!, left_orth!, right_orth!, left_null!, right_null!, truncate!, findtruncated, findtruncated_sorted, - diagview + diagview, isisometry include("utility.jl") include("interface.jl") diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index ad95a31d2..45979a206 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -42,8 +42,8 @@ for f! in (:qr_compact!, :qr_full!, end end -# Handle these separately because single N instead of tuple -for f! in (:qr_null!, :lq_null!) +# Handle these separately because single output instead of tuple +for f! in (:qr_null!, :lq_null!, :svd_vals!, :eig_vals!, :eigh_vals!) @eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) check_input($f!, t, N) @@ -94,7 +94,12 @@ function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ): return nothing end -# TODO: svd_vals +function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict) + @check_scalar S t real + V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) + @check_space(S, V_cod ← V_dom) + return nothing +end function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_cod = fuse(codomain(t)) @@ -114,12 +119,14 @@ function initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, return U, S, Vᴴ end -function initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, - alg::TruncatedAlgorithm) +function initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, alg::TruncatedAlgorithm) return initialize_output(svd_compact!, t, alg.alg) end -# TODO: svd_vals +function initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) + V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) + return DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) +end function svd_trunc!(t::AbstractTensorMap, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) @@ -162,6 +169,20 @@ function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV) return nothing end +function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap) + @check_scalar D t real + V_D = fuse(domain(t)) + @check_space(D, V_D ← V_D) + return nothing +end + +function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap) + @check_scalar D t complex + V_D = fuse(domain(t)) + @check_space(D, V_D ← V_D) + return nothing +end + function initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_D = fuse(domain(t)) T = real(scalartype(t)) @@ -178,6 +199,18 @@ function initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::Abstract return D, V end +function initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) + V_D = fuse(domain(t)) + T = real(scalartype(t)) + D = DiagonalTensorMap{Tc}(undef, V_D) +end + +function initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) + V_D = fuse(domain(t)) + Tc = complex(scalartype(t)) + D = DiagonalTensorMap{Tc}(undef, V_D) +end + function initialize_output(::typeof(eigh_trunc!), t::AbstractTensorMap, alg::TruncatedAlgorithm) return initialize_output(eigh_full!, t, alg.alg) From dbc5e14904ebc99098ae1fda75e77ae2f938100d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 6 Aug 2025 16:38:01 -0400 Subject: [PATCH 054/126] small fixes --- src/tensors/factorizations/implementations.jl | 38 +++++++++++++++---- .../factorizations/matrixalgebrakit.jl | 4 +- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index a53fd1680..55d40fad8 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -5,7 +5,7 @@ _kindof(::Polar) = :polar leftorth!(t::AbstractTensorMap; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) -function _leftorth!(t::AbstractTensorMap, ::Nothing; kwargs...) +function _leftorth!(t::AbstractTensorMap, alg::Nothing,; kwargs...) return isempty(kwargs) ? left_orth!(t) : left_orth!(t; trunc=(; kwargs...)) end function _leftorth!(t::AbstractTensorMap, alg::Union{QL,QLpos}; kwargs...) @@ -22,9 +22,14 @@ end function _leftorth!(t, alg::OFA; kwargs...) trunc = isempty(kwargs) ? nothing : (; kwargs...) + Base.depwarn(lazy"$alg is deprecated", :leftorth!) + kind = _kindof(alg) if kind == :svd - return left_orth!(t; kind, alg_svd=alg, trunc) + alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg === SDD() ? LAPACK_DivideAndConquer() : + throw(ArgumentError(lazy"Unknown algorithm $alg")) + return left_orth!(t; kind, alg_svd, trunc) elseif kind == :qr alg_qr = (; positive=(alg == QRpos())) return left_orth!(t; kind, alg_qr, trunc) @@ -47,7 +52,10 @@ function leftnull!(t::AbstractTensorMap; kind = _kindof(alg) if kind == :svd - return left_null!(t; kind, alg_svd=alg, trunc) + alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg === SDD() ? LAPACK_DivideAndConquer() : + throw(ArgumentError(lazy"Unknown algorithm $alg")) + return left_null!(t; kind, alg_svd, trunc) elseif kind == :qr alg_qr = (; positive=(alg == QRpos())) return left_null!(t; kind, alg_qr, trunc) @@ -76,7 +84,10 @@ function rightorth!(t::AbstractTensorMap; kind = _kindof(alg) if kind == :svd - return right_orth!(t; kind, alg_svd=alg, trunc) + alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg === SDD() ? LAPACK_DivideAndConquer() : + throw(ArgumentError(lazy"Unknown algorithm $alg")) + return right_orth!(t; kind, alg_svd, trunc) elseif kind == :lq alg_lq = (; positive=(alg == LQpos())) return right_orth!(t; kind, alg_lq, trunc) @@ -97,7 +108,10 @@ function rightnull!(t::AbstractTensorMap; kind = _kindof(alg) if kind == :svd - return right_null!(t; kind, alg_svd=alg, trunc) + alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg === SDD() ? LAPACK_DivideAndConquer() : + throw(ArgumentError(lazy"Unknown algorithm $alg")) + return right_null!(t; kind, alg_svd, trunc) elseif kind == :lq alg_lq = (; positive=(alg == LQpos())) return right_null!(t; kind, alg_lq, trunc) @@ -121,6 +135,7 @@ end function eig!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eig!) + if trunc == notrunc() return eig_full!(t; kwargs...) else @@ -134,13 +149,20 @@ end # Singular value decomposition # ---------------------------- -function tsvd!(t::AbstractTensorMap; trunc=notrunc(), p=nothing, kwargs...) +function tsvd!(t::AbstractTensorMap; trunc=notrunc(), p=nothing, alg=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) isnothing(p) || Base.depwarn("p is no longer supported", :tsvd!) + if alg isa OFA + Base.depwarn(lazy"$alg is deprecated", :tsvd!) + alg = alg === SVD() ? LAPACK_QRIteration() : + alg === SDD() ? LAPACK_DivideAndConquer() : + throw(ArgumentError(lazy"Unknown algorithm $alg")) + end + if trunc == notrunc() - return svd_compact!(t; kwargs...) + return svd_compact!(t; alg, kwargs...) else - return svd_trunc!(t; trunc, kwargs...) + return svd_trunc!(t; trunc, alg, kwargs...) end end diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 45979a206..f5429cf44 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -440,11 +440,11 @@ end # Needed to get algorithm selection to behave function left_orth_polar!(t::AbstractTensorMap, VC, alg) - alg′ = select_algorithm(left_polar!, t, alg) + alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg) return left_orth_polar!(t, VC, alg′) end function right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg) - alg′ = select_algorithm(right_polar!, t, alg) + alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg) return right_orth_polar!(t, CVᴴ, alg′) end From a6a04449020653406b4c049938db1d4ce18d91e3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 6 Aug 2025 16:40:17 -0400 Subject: [PATCH 055/126] fix docstring --- src/spaces/vectorspaces.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spaces/vectorspaces.jl b/src/spaces/vectorspaces.jl index 92fd780ef..847182536 100644 --- a/src/spaces/vectorspaces.jl +++ b/src/spaces/vectorspaces.jl @@ -154,8 +154,8 @@ const oplus = ⊕ ⊖(V::ElementarySpace, W::ElementarySpace) -> X::ElementarySpace ominus(V::ElementarySpace, W::ElementarySpace) -> X::ElementarySpace -Return the set difference of two elementary spaces, i.e. an instance `X::ElementarySpace` -such that `V = W ⊕ X`. +Return a space that is equivalent to the orthogonal complement of `W` in `V`, +i.e. an instance `X::ElementarySpace` such that `V = W ⊕ X`. """ ⊖(V₁::S, V₂::S) where {S<:ElementarySpace} ⊖(V₁::VectorSpace, V₂::VectorSpace) = ⊖(promote(V₁, V₂)...) From bf6aaaaf8cde308958afc064bca5e772c8e269d4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 6 Aug 2025 16:42:59 -0400 Subject: [PATCH 056/126] remove leftpolar and rightpolar --- src/TensorKit.jl | 4 ++-- src/tensors/factorizations/factorizations.jl | 1 - src/tensors/factorizations/implementations.jl | 4 ---- src/tensors/factorizations/interface.jl | 16 +--------------- test/tensors.jl | 4 ++-- 5 files changed, 5 insertions(+), 24 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index b087de47e..7888d9a66 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -70,8 +70,8 @@ export inner, dot, norm, normalize, normalize!, tr # factorizations export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby! -export leftorth, rightorth, leftnull, rightnull, leftpolar, rightpolar, - leftorth!, rightorth!, leftnull!, rightnull!, leftpolar!, rightpolar!, +export leftorth, rightorth, leftnull, rightnull, + leftorth!, rightorth!, leftnull!, rightnull!, tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 40d16c447..eb1913cad 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -7,7 +7,6 @@ export eig, eig!, eigh, eigh! export tsvd, tsvd!, svdvals, svdvals! export leftorth, leftorth!, rightorth, rightorth! export leftnull, leftnull!, rightnull, rightnull! -export leftpolar, leftpolar!, rightpolar, rightpolar! export copy_oftype, permutedcopy_oftype export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index 55d40fad8..20e5c3135 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -64,8 +64,6 @@ function leftnull!(t::AbstractTensorMap; end end -leftpolar!(t::AbstractTensorMap; kwargs...) = left_polar!(t; kwargs...) - function rightorth!(t::AbstractTensorMap; alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar,Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || @@ -120,8 +118,6 @@ function rightnull!(t::AbstractTensorMap; end end -rightpolar!(t::AbstractTensorMap; kwargs...) = right_polar!(t; kwargs...) - # Eigenvalue decomposition # ------------------------ function eigh!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) diff --git a/src/tensors/factorizations/interface.jl b/src/tensors/factorizations/interface.jl index 56d462887..78c40e517 100644 --- a/src/tensors/factorizations/interface.jl +++ b/src/tensors/factorizations/interface.jl @@ -182,20 +182,6 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `InnerProductStyle(t) === EuclideanInnerProduct()`. """ rightnull, rightnull! -@doc """ - leftpolar(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> W, P - leftpolar!(t::AbstractTensorMap; kwargs...) -> W, P - -Compute the polar decomposition of tensor `t` as linear map from `rightind` to `leftind`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `eigh!(t)`. - -See also [`rightpolar(!)`](@ref rightpolar). - -""" leftpolar, leftpolar! - @doc """ eigen(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> D, V eigen!(t::AbstractTensorMap; kwargs...) -> D, V @@ -230,7 +216,7 @@ meaningless. """ isposdef(::AbstractTensorMap), isposdef!(::AbstractTensorMap) for f in - (:tsvd, :eig, :eigh, :eigen, :leftorth, :rightorth, :leftpolar, :rightpolar, :leftnull, + (:tsvd, :eig, :eigh, :eigen, :leftorth, :rightorth, :left_polar, :right_polar, :leftnull, :rightnull, :isposdef) f! = Symbol(f, :!) @eval function $f(t::AbstractTensorMap, p::Index2Tuple; kwargs...) diff --git a/test/tensors.jl b/test/tensors.jl index 5d6980ddb..dcb5ee2cf 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -686,8 +686,8 @@ for V in spacelist for T in (Float32, ComplexF64) tA = rand(T, V1 ⊗ V3, V1 ⊗ V3) tB = rand(T, V2 ⊗ V4, V2 ⊗ V4) - tA = 3 // 2 * leftpolar(tA)[1] - tB = 1 // 5 * leftpolar(tB)[1] + tA = 3 // 2 * left_polar(tA)[1] + tB = 1 // 5 * left_polar(tB)[1] tC = rand(T, V1 ⊗ V3, V2 ⊗ V4) t = @constinferred sylvester(tA, tB, tC) @test codomain(t) == V1 ⊗ V3 From 017318eb073a3160b4d15f6d2008e5758434bb81 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 6 Aug 2025 16:44:14 -0400 Subject: [PATCH 057/126] format --- src/tensors/factorizations/implementations.jl | 2 +- src/tensors/factorizations/interface.jl | 3 ++- src/tensors/factorizations/matrixalgebrakit.jl | 16 ++++++++++------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index 20e5c3135..02c59b492 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -5,7 +5,7 @@ _kindof(::Polar) = :polar leftorth!(t::AbstractTensorMap; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) -function _leftorth!(t::AbstractTensorMap, alg::Nothing,; kwargs...) +function _leftorth!(t::AbstractTensorMap, alg::Nothing, ; kwargs...) return isempty(kwargs) ? left_orth!(t) : left_orth!(t; trunc=(; kwargs...)) end function _leftorth!(t::AbstractTensorMap, alg::Union{QL,QLpos}; kwargs...) diff --git a/src/tensors/factorizations/interface.jl b/src/tensors/factorizations/interface.jl index 78c40e517..fc757a298 100644 --- a/src/tensors/factorizations/interface.jl +++ b/src/tensors/factorizations/interface.jl @@ -216,7 +216,8 @@ meaningless. """ isposdef(::AbstractTensorMap), isposdef!(::AbstractTensorMap) for f in - (:tsvd, :eig, :eigh, :eigen, :leftorth, :rightorth, :left_polar, :right_polar, :leftnull, + (:tsvd, :eig, :eigh, :eigen, :leftorth, :rightorth, :left_polar, :right_polar, + :leftnull, :rightnull, :isposdef) f! = Symbol(f, :!) @eval function $f(t::AbstractTensorMap, p::Index2Tuple; kwargs...) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index f5429cf44..91aa8a5c0 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -119,11 +119,13 @@ function initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, return U, S, Vᴴ end -function initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, alg::TruncatedAlgorithm) +function initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, + alg::TruncatedAlgorithm) return initialize_output(svd_compact!, t, alg.alg) end -function initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) +function initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, + alg::AbstractAlgorithm) V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) return DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) end @@ -199,16 +201,18 @@ function initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::Abstract return D, V end -function initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) +function initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, + alg::AbstractAlgorithm) V_D = fuse(domain(t)) T = real(scalartype(t)) - D = DiagonalTensorMap{Tc}(undef, V_D) + return D = DiagonalTensorMap{Tc}(undef, V_D) end -function initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) +function initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, + alg::AbstractAlgorithm) V_D = fuse(domain(t)) Tc = complex(scalartype(t)) - D = DiagonalTensorMap{Tc}(undef, V_D) + return D = DiagonalTensorMap{Tc}(undef, V_D) end function initialize_output(::typeof(eigh_trunc!), t::AbstractTensorMap, From fe1da6603f2da17f8d8a2617079a087ad5ad9e75 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 18 Aug 2025 15:40:03 +0200 Subject: [PATCH 058/126] Tests passing --- Project.toml | 6 +-- src/TensorKit.jl | 1 + .../factorizations/matrixalgebrakit.jl | 46 +++++++++---------- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 6843ce3db..a9e1f7653 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" @@ -34,15 +33,12 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.2.5" +MatrixAlgebraKit = "0.3" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" -<<<<<<< HEAD -======= ScopedValues = "1.3.0" SparseArrays = "1" ->>>>>>> 3f9c871 (Add scheduler support) Strided = "2" TensorKitSectors = "0.1.4, 0.2" TensorOperations = "5.1" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 7888d9a66..038bdce0e 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -72,6 +72,7 @@ export inner, dot, norm, normalize, normalize!, tr export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby! export leftorth, rightorth, leftnull, rightnull, leftorth!, rightorth!, leftnull!, rightnull!, + left_polar, left_polar!, right_polar, right_polar!, tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 91aa8a5c0..1f40309b6 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -26,7 +26,7 @@ for f! in (:qr_compact!, :qr_full!, :svd_compact!, :svd_full!, :left_polar!, :left_orth_polar!, :right_polar!, :right_orth_polar!) @eval function $f!(t::AbstractTensorMap, F, alg::AbstractAlgorithm) - check_input($f!, t, F) + check_input($f!, t, F, alg) foreachblock(t, F...) do _, bs factors = Base.tail(bs) @@ -45,7 +45,7 @@ end # Handle these separately because single output instead of tuple for f! in (:qr_null!, :lq_null!, :svd_vals!, :eig_vals!, :eigh_vals!) @eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) - check_input($f!, t, N) + check_input($f!, t, N, alg) foreachblock(t, N) do _, (b, n) n′ = $f!(b, n, alg) @@ -63,7 +63,7 @@ end const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} const _T_USVᴴ_diag = Tuple{<:AbstractTensorMap,<:DiagonalTensorMap,<:AbstractTensorMap} -function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ) +function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ, ::AbstractAlgorithm) # scalartype checks @check_scalar U t @check_scalar S t real @@ -79,7 +79,7 @@ function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T return nothing end -function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ_diag) +function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ_diag, ::AbstractAlgorithm) # scalartype checks @check_scalar U t @check_scalar S t real @@ -94,7 +94,7 @@ function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ): return nothing end -function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict) +function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict, ::AbstractAlgorithm) @check_scalar S t real V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) @check_space(S, V_cod ← V_dom) @@ -139,7 +139,7 @@ end # ------------------------ const _T_DV = Tuple{<:DiagonalTensorMap,<:AbstractTensorMap} -function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV) +function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV, ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -155,7 +155,7 @@ function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV) return nothing end -function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV) +function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV, ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -171,14 +171,14 @@ function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV) return nothing end -function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap) +function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, ::AbstractAlgorithm) @check_scalar D t real V_D = fuse(domain(t)) @check_space(D, V_D ← V_D) return nothing end -function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap) +function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, ::AbstractAlgorithm) @check_scalar D t complex V_D = fuse(domain(t)) @check_space(D, V_D ← V_D) @@ -239,7 +239,7 @@ end # ---------------- const _T_QR = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR) +function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR, ::AbstractAlgorithm) # scalartype checks @check_scalar Q t @check_scalar R t @@ -252,7 +252,7 @@ function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR) return nothing end -function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR) +function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR, ::AbstractAlgorithm) # scalartype checks @check_scalar Q t @check_scalar R t @@ -265,7 +265,7 @@ function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR) return nothing end -function check_input(::typeof(qr_null!), t::AbstractTensorMap, N::AbstractTensorMap) +function check_input(::typeof(qr_null!), t::AbstractTensorMap, N::AbstractTensorMap, ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -302,7 +302,7 @@ end # ---------------- const _T_LQ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ) +function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ, ::AbstractAlgorithm) # scalartype checks @check_scalar L t @check_scalar Q t @@ -315,7 +315,7 @@ function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ) return nothing end -function check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)::_T_LQ) +function check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)::_T_LQ, ::AbstractAlgorithm) # scalartype checks @check_scalar L t @check_scalar Q t @@ -328,7 +328,7 @@ function check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)::_T_LQ) return nothing end -function check_input(::typeof(lq_null!), t::AbstractTensorMap, N) +function check_input(::typeof(lq_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -367,7 +367,7 @@ const _T_WP = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} const _T_PWᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} using MatrixAlgebraKit: PolarViaSVD -function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP) +function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP, ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) @@ -382,7 +382,7 @@ function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP) return nothing end -function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, (W, P)::_T_WP) +function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, (W, P)::_T_WP, ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) @@ -404,7 +404,7 @@ function initialize_output(::typeof(left_polar!), t::AbstractTensorMap, ::Abstra return W, P end -function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ) +function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) @@ -419,7 +419,7 @@ function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T return nothing end -function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ) +function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) @@ -457,7 +457,7 @@ end const _T_VC = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} const _T_CVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC) +function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC, ::AbstractAlgorithm) # scalartype checks @check_scalar V t isnothing(C) || @check_scalar C t @@ -470,7 +470,7 @@ function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC) return nothing end -function check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)::_T_CVᴴ) +function check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)::_T_CVᴴ, ::AbstractAlgorithm) # scalartype checks isnothing(C) || @check_scalar C t @check_scalar Vᴴ t @@ -499,7 +499,7 @@ end # Nullspace # --------- -function check_input(::typeof(left_null!), t::AbstractTensorMap, N) +function check_input(::typeof(left_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -511,7 +511,7 @@ function check_input(::typeof(left_null!), t::AbstractTensorMap, N) return nothing end -function check_input(::typeof(right_null!), t::AbstractTensorMap, N) +function check_input(::typeof(right_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) @check_scalar N t # space checks From 628ab71031b6274409a09060b5e4496b0e0ab72e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 Aug 2025 12:35:01 +0200 Subject: [PATCH 059/126] Added a few more eig/eigh/svd tests --- test/diagonal.jl | 10 ++++++++++ test/tensors.jl | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/test/diagonal.jl b/test/diagonal.jl index e5fcaa430..0f67836dd 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -200,6 +200,16 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @test all(((s, t),) -> isapprox(s, t), zip(values(LinearAlgebra.eigvals(D)), values(LinearAlgebra.eigvals(t)))) + D, W = @constinferred eig!(t) + @test D === t + @test W == one(t) + @test t * W ≈ W * D + D2, V2 = @constinferred eigh!(t2) + if T <: Real + @test D2 === t2 + end + @test V2 == one(t) + @test t2 * V2 ≈ V2 * D2 end @testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QL()) Q, R = @constinferred leftorth(t; alg=alg) diff --git a/test/tensors.jl b/test/tensors.jl index dcb5ee2cf..b827d9f12 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -490,6 +490,11 @@ for V in spacelist for (c, b) in s @test b ≈ s′[c] end + s = LinearAlgebra.svdvals(t2') + s′ = LinearAlgebra.diag(S') + for (c, b) in s + @test b ≈ s′[c] + end end @testset "cond and rank" begin t2 = permute(t, ((3, 4, 2), (1, 5))) @@ -507,6 +512,10 @@ for V in spacelist λmax = maximum(s -> maximum(abs, s), values(vals)) λmin = minimum(s -> minimum(abs, s), values(vals)) @test cond(t4) ≈ λmax / λmin + vals = LinearAlgebra.eigvals(t4') + λmax = maximum(s -> maximum(abs, s), values(vals)) + λmin = minimum(s -> minimum(abs, s), values(vals)) + @test cond(t4') ≈ λmax / λmin end end @testset "empty tensor" begin From 660ac632b842710b90269e90f5672affa2bfd838 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 Aug 2025 14:20:17 +0200 Subject: [PATCH 060/126] Excise MatrixAlgebra module entirely to use MatrixAlgebraKit --- src/auxiliary/linalg.jl | 330 ------------------ src/auxiliary/random.jl | 2 +- src/tensors/factorizations/factorizations.jl | 3 +- src/tensors/factorizations/implementations.jl | 2 +- .../factorizations/matrixalgebrakit.jl | 6 +- src/tensors/linalg.jl | 12 +- 6 files changed, 13 insertions(+), 342 deletions(-) diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index 82e8600f0..5be245fc5 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -46,35 +46,6 @@ Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg const OFA = OrthogonalFactorizationAlgorithm const SVDAlg = Union{SVD,SDD} -# Matrix algebra: entrypoint for calling matrix methods from within tensor implementations -#------------------------------------------------------------------------------------------ -module MatrixAlgebra -# TODO: all methods that we define here will need an extended version for CuMatrix/ROCMatrix in the -# CUDA/AMD package extension. - -# TODO: other methods to include here: -# mul! (possibly call matmul! instead) -# adjoint! -# sylvester -# exp! -# schur!? -# - -using LinearAlgebra -using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, checksquare - -using ..TensorKit: OrthogonalFactorizationAlgorithm, - QL, QLpos, QR, QRpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar - -# only defined in >v1.7 -@static if VERSION < v"1.7-" - _rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im) - _argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2] -else - _argmax(f, domain) = argmax(f, domain) -end - -# TODO: define for CuMatrix if we support this function one!(A::StridedMatrix) length(A) > 0 || return A copyto!(A, LinearAlgebra.I) @@ -83,304 +54,3 @@ end safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s)) safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) - -function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real) - iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) - m, n = size(A) - k = min(m, n) - A, T = LAPACK.geqrt!(A, min(minimum(size(A)), 36)) - Q = similar(A, m, k) - for j in 1:k - for i in 1:m - Q[i, j] = i == j - end - end - Q = LAPACK.gemqrt!('L', 'N', A, T, Q) - R = triu!(A[1:k, :]) - - if isa(alg, QRpos) - @inbounds for j in 1:k - s = safesign(R[j, j]) - @simd for i in 1:m - Q[i, j] *= s - end - end - @inbounds for j in size(R, 2):-1:1 - for i in 1:min(k, j) - R[i, j] = R[i, j] * conj(safesign(R[i, i])) - end - end - end - return Q, R -end - -function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{QL,QLpos}, atol::Real) - iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) - m, n = size(A) - @assert m >= n - - nhalf = div(n, 2) - #swap columns in A - @inbounds for j in 1:nhalf, i in 1:m - A[i, j], A[i, n + 1 - j] = A[i, n + 1 - j], A[i, j] - end - Q, R = leftorth!(A, isa(alg, QL) ? QR() : QRpos(), atol) - - #swap columns in Q - @inbounds for j in 1:nhalf, i in 1:m - Q[i, j], Q[i, n + 1 - j] = Q[i, n + 1 - j], Q[i, j] - end - #swap rows and columns in R - @inbounds for j in 1:nhalf, i in 1:n - R[i, j], R[n + 1 - i, n + 1 - j] = R[n + 1 - i, n + 1 - j], R[i, j] - end - if isodd(n) - j = nhalf + 1 - @inbounds for i in 1:nhalf - R[i, j], R[n + 1 - i, j] = R[n + 1 - i, j], R[i, j] - end - end - return Q, R -end - -function leftorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real) - U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A) : LAPACK.gesdd!('S', A) - if isa(alg, Union{SVD,SDD}) - n = count(s -> s .> atol, S) - if n != length(S) - return U[:, 1:n], lmul!(Diagonal(S[1:n]), V[1:n, :]) - else - return U, lmul!(Diagonal(S), V) - end - else - iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) - # TODO: check Lapack to see if we can recycle memory of A - Q = mul!(A, U, V) - Sq = map!(sqrt, S, S) - SqV = lmul!(Diagonal(Sq), V) - R = SqV' * SqV - return Q, R - end -end - -function leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{QR,QRpos}, atol::Real) - iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) - m, n = size(A) - m >= n || throw(ArgumentError("no null space if less rows than columns")) - - A, T = LAPACK.geqrt!(A, min(minimum(size(A)), 36)) - N = similar(A, m, max(0, m - n)) - fill!(N, 0) - for k in 1:(m - n) - N[n + k, k] = 1 - end - return N = LAPACK.gemqrt!('L', 'N', A, T, N) -end - -function leftnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real) - size(A, 2) == 0 && return one!(similar(A, (size(A, 1), size(A, 1)))) - U, S, V = alg isa SVD ? LAPACK.gesvd!('A', 'N', A) : LAPACK.gesdd!('A', A) - indstart = count(>(atol), S) + 1 - return U[:, indstart:end] -end - -function rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos,RQ,RQpos}, - atol::Real) - iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) - # TODO: geqrfp seems a bit slower than geqrt in the intermediate region around - # matrix size 100, which is the interesting region. => Investigate and fix - m, n = size(A) - k = min(m, n) - At = transpose!(similar(A, n, m), A) - - if isa(alg, RQ) || isa(alg, RQpos) - @assert m <= n - - mhalf = div(m, 2) - # swap columns in At - @inbounds for j in 1:mhalf, i in 1:n - At[i, j], At[i, m + 1 - j] = At[i, m + 1 - j], At[i, j] - end - Qt, Rt = leftorth!(At, isa(alg, RQ) ? QR() : QRpos(), atol) - - @inbounds for j in 1:mhalf, i in 1:n - Qt[i, j], Qt[i, m + 1 - j] = Qt[i, m + 1 - j], Qt[i, j] - end - @inbounds for j in 1:mhalf, i in 1:m - Rt[i, j], Rt[m + 1 - i, m + 1 - j] = Rt[m + 1 - i, m + 1 - j], Rt[i, j] - end - if isodd(m) - j = mhalf + 1 - @inbounds for i in 1:mhalf - Rt[i, j], Rt[m + 1 - i, j] = Rt[m + 1 - i, j], Rt[i, j] - end - end - Q = transpose!(A, Qt) - R = transpose!(similar(A, (m, m)), Rt) # TODO: efficient in place - return R, Q - else - Qt, Lt = leftorth!(At, alg', atol) - if m > n - L = transpose!(A, Lt) - Q = transpose!(similar(A, (n, n)), Qt) # TODO: efficient in place - else - Q = transpose!(A, Qt) - L = transpose!(similar(A, (m, m)), Lt) # TODO: efficient in place - end - return L, Q - end -end - -function rightorth!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD,Polar}, atol::Real) - U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A) : LAPACK.gesdd!('S', A) - if isa(alg, Union{SVD,SDD}) - n = count(s -> s .> atol, S) - if n != length(S) - return rmul!(U[:, 1:n], Diagonal(S[1:n])), V[1:n, :] - else - return rmul!(U, Diagonal(S)), V - end - else - iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) - Q = mul!(A, U, V) - Sq = map!(sqrt, S, S) - USq = rmul!(U, Diagonal(Sq)) - L = USq * USq' - return L, Q - end -end - -function rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{LQ,LQpos}, atol::Real) - iszero(atol) || throw(ArgumentError("nonzero atol not supported by $alg")) - m, n = size(A) - k = min(m, n) - At = adjoint!(similar(A, n, m), A) - At, T = LAPACK.geqrt!(At, min(k, 36)) - N = similar(A, max(n - m, 0), n) - fill!(N, 0) - for k in 1:(n - m) - N[k, m + k] = 1 - end - return N = LAPACK.gemqrt!('R', eltype(At) <: Real ? 'T' : 'C', At, T, N) -end - -function rightnull!(A::StridedMatrix{<:BlasFloat}, alg::Union{SVD,SDD}, atol::Real) - size(A, 1) == 0 && return one!(similar(A, (size(A, 2), size(A, 2)))) - U, S, V = alg isa SVD ? LAPACK.gesvd!('N', 'A', A) : LAPACK.gesdd!('A', A) - indstart = count(>(atol), S) + 1 - return V[indstart:end, :] -end - -function svd!(A::StridedMatrix{T}, alg::Union{SVD,SDD}) where {T<:BlasFloat} - # fix another type instability in LAPACK wrappers - TT = Tuple{Matrix{T},Vector{real(T)},Matrix{T}} - U, S, V = alg isa SVD ? LAPACK.gesvd!('S', 'S', A)::TT : LAPACK.gesdd!('S', A)::TT - return U, S, V -end - -function eig!(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) where {T<:BlasReal} - n = checksquare(A) - n == 0 && return zeros(Complex{T}, 0), zeros(Complex{T}, 0, 0) - - A, DR, DI, VL, VR, _ = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') : - (scale ? 'S' : 'N'), 'N', 'V', 'N', A) - D = complex.(DR, DI) - V = zeros(Complex{T}, n, n) - j = 1 - while j <= n - if DI[j] == 0 - vr = view(VR, :, j) - s = conj(sign(_argmax(abs, vr))) - V[:, j] .= s .* vr - else - vr = view(VR, :, j) - vi = view(VR, :, j + 1) - s = conj(sign(_argmax(abs, vr))) # vectors coming from lapack have already real absmax component - V[:, j] .= s .* (vr .+ im .* vi) - V[:, j + 1] .= s .* (vr .- im .* vi) - j += 1 - end - j += 1 - end - return D, V -end - -function eig!(A::StridedMatrix{T}; permute::Bool=true, - scale::Bool=true) where {T<:BlasComplex} - n = checksquare(A) - n == 0 && return zeros(T, 0), zeros(T, 0, 0) - D, V = LAPACK.geevx!(permute ? (scale ? 'B' : 'P') : (scale ? 'S' : 'N'), 'N', 'V', 'N', - A)[[2, 4]] - for j in 1:n - v = view(V, :, j) - s = conj(sign(_argmax(abs, v))) - v .*= s - end - return D, V -end - -function eigh!(A::StridedMatrix{T}) where {T<:BlasFloat} - n = checksquare(A) - n == 0 && return zeros(real(T), 0), zeros(T, 0, 0) - D, V = LAPACK.syevr!('V', 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0) - for j in 1:n - v = view(V, :, j) - s = conj(sign(_argmax(abs, v))) - v .*= s - end - return D, V -end - -## Old stuff and experiments - -# using LinearAlgebra: BlasFloat, Char, BlasInt, LAPACK, LAPACKException, -# DimensionMismatch, SingularException, PosDefException, chkstride1, -# checksquare, -# triu! - -# TODO: reconsider the following implementation -# Unfortunately, geqrfp seems a bit slower than geqrt in the intermediate region -# around matrix size 100, which is the interesting region. => Investigate and maybe fix -# function _leftorth!(A::StridedMatrix{<:BlasFloat}) -# m, n = size(A) -# A, τ = geqrfp!(A) -# Q = LAPACK.ormqr!('L', 'N', A, τ, eye(eltype(A), m, min(m, n))) -# R = triu!(A[1:min(m, n), :]) -# return Q, R -# end - -# geqrfp!: computes qrpos factorization, missing in Base -# geqrfp!(A::StridedMatrix{<:BlasFloat}) = -# ((m, n) = size(A); geqrfp!(A, similar(A, min(m, n)))) -# -# for (geqrfp, elty, relty) in -# ((:dgeqrfp_, :Float64, :Float64), (:sgeqrfp_, :Float32, :Float32), -# (:zgeqrfp_, :ComplexF64, :Float64), (:cgeqrfp_, :ComplexF32, :Float32)) -# @eval begin -# function geqrfp!(A::StridedMatrix{$elty}, tau::StridedVector{$elty}) -# chkstride1(A, tau) -# m, n = size(A) -# if length(tau) != min(m, n) -# throw(DimensionMismatch("tau has length $(length(tau)), but needs length $(min(m, n))")) -# end -# work = Vector{$elty}(1) -# lwork = BlasInt(-1) -# info = Ref{BlasInt}() -# for i = 1:2 # first call returns lwork as work[1] -# ccall((@blasfunc($geqrfp), liblapack), Nothing, -# (Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, -# Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}), -# Ref(m), Ref(n), A, Ref(max(1, stride(A, 2))), -# tau, work, Ref(lwork), info) -# chklapackerror(info[]) -# if i == 1 -# lwork = BlasInt(real(work[1])) -# resize!(work, lwork) -# end -# end -# A, tau -# end -# end -# end - -end diff --git a/src/auxiliary/random.jl b/src/auxiliary/random.jl index bc9df6f65..3289cdc3c 100644 --- a/src/auxiliary/random.jl +++ b/src/auxiliary/random.jl @@ -20,6 +20,6 @@ function randisometry!(rng::Random.AbstractRNG, A::AbstractMatrix) dims = size(A) dims[1] >= dims[2] || throw(DimensionMismatch("cannot create isometric matrix with dimensions $dims; isometry needs to be tall or square")) - Q, = MatrixAlgebra.leftorth!(Random.randn!(rng, A), QRpos(), 0) + Q, = leftorth!(Random.randn!(rng, A); alg=QRpos()) return copy!(A, Q) end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index eb1913cad..c4db4beb8 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -12,7 +12,6 @@ export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace using ..TensorKit using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock -using ..MatrixAlgebra: MatrixAlgebra using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals! import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian @@ -112,7 +111,7 @@ function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD}) V = zerovector!(similar(b.diag, lb, lb)) p = sortperm(b.diag; by=abs, rev=true) for (i, pi) in enumerate(p) - U[pi, i] = MatrixAlgebra.safesign(b.diag[pi]) + U[pi, i] = safesign(b.diag[pi]) V[i, pi] = 1 end Σ = abs.(view(b.diag, p)) diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index 02c59b492..d00703549 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -3,7 +3,7 @@ _kindof(::Union{QR,QRpos}) = :qr _kindof(::Union{LQ,LQpos}) = :lq _kindof(::Polar) = :polar -leftorth!(t::AbstractTensorMap; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) +leftorth!(t; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) function _leftorth!(t::AbstractTensorMap, alg::Nothing, ; kwargs...) return isempty(kwargs) ? left_orth!(t) : left_orth!(t; trunc=(; kwargs...)) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 1f40309b6..1513a1576 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -3,7 +3,7 @@ for f! in [:svd_compact!, :svd_full!, :svd_trunc!, :svd_vals!, :qr_compact!, :qr_full!, :qr_null!, :lq_compact!, :lq_full!, :lq_null!, :eig_full!, :eig_trunc!, :eig_vals!, :eigh_full!, - :eigh_trunc!, :eigh_vals!, :left_polar!, :right_polar!] + :eigh_trunc!, :eigh_vals!, :left_polar!, :right_polar!, :left_orth!, :right_orth!] @eval function default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T<:AbstractTensorMap} return default_algorithm($f!, blocktype(T); kwargs...) @@ -24,7 +24,9 @@ for f! in (:qr_compact!, :qr_full!, :lq_compact!, :lq_full!, :eig_full!, :eigh_full!, :svd_compact!, :svd_full!, - :left_polar!, :left_orth_polar!, :right_polar!, :right_orth_polar!) + :left_polar!, :left_orth_polar!, + :right_polar!, :right_orth_polar!, + :left_orth!, :right_orth!) @eval function $f!(t::AbstractTensorMap, F, alg::AbstractAlgorithm) check_input($f!, t, F, alg) diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 5eba8414c..b52e8b4bc 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -59,7 +59,7 @@ function one!(t::AbstractTensorMap) domain(t) == codomain(t) || throw(SectorMismatch("no identity if domain and codomain are different")) for (c, b) in blocks(t) - MatrixAlgebra.one!(b) + one!(b) end return t end @@ -106,7 +106,7 @@ function isomorphism!(t::AbstractTensorMap) domain(t) ≅ codomain(t) || throw(SpaceMismatch(lazy"domain and codomain are not isomorphic: $(space(t))")) for (_, b) in blocks(t) - MatrixAlgebra.one!(b) + one!(b) end return t end @@ -155,7 +155,7 @@ function isometry!(t::AbstractTensorMap) domain(t) ≾ codomain(t) || throw(SpaceMismatch(lazy"domain and codomain are not isometrically embeddable: $(space(t))")) for (_, b) in blocks(t) - MatrixAlgebra.one!(b) + one!(b) end return t end @@ -377,7 +377,7 @@ function Base.inv(t::AbstractTensorMap) T = float(scalartype(t)) tinv = similar(t, T, dom ← cod) for (c, b) in blocks(t) - binv = MatrixAlgebra.one!(block(tinv, c)) + binv = one!(block(tinv, c)) ldiv!(lu(b), binv) end return tinv @@ -449,11 +449,11 @@ for f in (:cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asi tf = similar(t, T) if T <: Real for (c, b) in blocks(t) - copy!(block(tf, c), real(MatrixAlgebra.$f(b))) + copy!(block(tf, c), real(MatrixAlgebraKit.$f(b))) end else for (c, b) in blocks(t) - copy!(block(tf, c), MatrixAlgebra.$f(b)) + copy!(block(tf, c), MatrixAlgebraKit.$f(b)) end end return tf From deda7efbc56004e96c94ab032794efd01e51c121 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 Aug 2025 15:10:21 +0200 Subject: [PATCH 061/126] Restore argmax and fix AD test --- src/auxiliary/linalg.jl | 2 ++ test/ad.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index 5be245fc5..9d8ac7818 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -54,3 +54,5 @@ end safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s)) safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) + +_argmax(f, domain) = argmax(f, domain) diff --git a/test/ad.jl b/test/ad.jl index 9f5eb2a5b..c5c2aabfc 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -447,7 +447,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV)) - c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), + c, = TensorKit._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) trunc = truncdim(round(Int, 2 * dim(c))) U, S, V = tsvd(C; trunc) From d60abfe71675de9e2d24aa2d0900eae92b7ccfe9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 Aug 2025 17:20:39 +0200 Subject: [PATCH 062/126] Respond to comments --- src/auxiliary/linalg.jl | 8 -------- src/tensors/factorizations/factorizations.jl | 6 ++++-- test/ad.jl | 3 +-- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index 9d8ac7818..8a20eb8ac 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -46,13 +46,5 @@ Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg const OFA = OrthogonalFactorizationAlgorithm const SVDAlg = Union{SVD,SDD} -function one!(A::StridedMatrix) - length(A) > 0 || return A - copyto!(A, LinearAlgebra.I) - return A -end - safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s)) safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) - -_argmax(f, domain) = argmax(f, domain) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index c4db4beb8..94124a04f 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -7,11 +7,11 @@ export eig, eig!, eigh, eigh! export tsvd, tsvd!, svdvals, svdvals! export leftorth, leftorth!, rightorth, rightorth! export leftnull, leftnull!, rightnull, rightnull! -export copy_oftype, permutedcopy_oftype +export copy_oftype, permutedcopy_oftype, one! export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace using ..TensorKit -using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock +using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock, one! using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals! import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian @@ -41,6 +41,8 @@ include("matrixalgebrakit.jl") include("truncation.jl") include("deprecations.jl") +TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A) + function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) t = permute(t, (p₁, p₂); copy=false) return isisometry(t) diff --git a/test/ad.jl b/test/ad.jl index c5c2aabfc..4013846d0 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -447,8 +447,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV)) - c, = TensorKit._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), - blocks(S)) + c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) trunc = truncdim(round(Int, 2 * dim(c))) U, S, V = tsvd(C; trunc) ΔU = randn(scalartype(U), space(U)) From de48d8747ad49424590275443eb8b080eee7e810 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 Aug 2025 18:04:21 +0200 Subject: [PATCH 063/126] Remove unneeded default_algorithm --- src/tensors/factorizations/matrixalgebrakit.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 1513a1576..2cd40b01a 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -3,7 +3,7 @@ for f! in [:svd_compact!, :svd_full!, :svd_trunc!, :svd_vals!, :qr_compact!, :qr_full!, :qr_null!, :lq_compact!, :lq_full!, :lq_null!, :eig_full!, :eig_trunc!, :eig_vals!, :eigh_full!, - :eigh_trunc!, :eigh_vals!, :left_polar!, :right_polar!, :left_orth!, :right_orth!] + :eigh_trunc!, :eigh_vals!, :left_polar!, :right_polar!] @eval function default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T<:AbstractTensorMap} return default_algorithm($f!, blocktype(T); kwargs...) From 6ee379bbaceb1d74a82c2645a69083eef20b7824 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 31 Aug 2025 20:26:25 +0200 Subject: [PATCH 064/126] clean up handling AdjointTensorMap --- src/tensors/factorizations/adjoint.jl | 90 ++++++++++++++++++++ src/tensors/factorizations/factorizations.jl | 32 +------ 2 files changed, 91 insertions(+), 31 deletions(-) create mode 100644 src/tensors/factorizations/adjoint.jl diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl new file mode 100644 index 000000000..918248c29 --- /dev/null +++ b/src/tensors/factorizations/adjoint.jl @@ -0,0 +1,90 @@ +# AdjointTensorMap +# ---------------- +# 1-arg functions +function initialize_output(::typeof(left_null!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return adjoint(initialize_output(right_null!, adjoint(t), alg)) +end +function initialize_output(::typeof(right_null!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return adjoint(initialize_output(left_null!, adjoint(t), alg)) +end + +function left_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) + right_null!(adjoint(t), adjoint(N), alg) + return N +end +function right_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) + left_null!(adjoint(t), adjoint(N), alg) + return N +end + +# 2-arg functions +for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!), + (:lq_full!, :lq_compact!, :right_polar!, :right_orth!)) + @eval function initialize_output(::typeof($left_f!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return reverse(adjoint.(initialize_output($right_f!, adjoint(t), alg))) + end + @eval function initialize_output(::typeof($right_f!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return reverse(adjoint.(initialize_output($left_f!, adjoint(t), alg))) + end + + @eval function $left_f!(t::AdjointTensorMap, + F::Tuple{AdjointTensorMap,AdjointTensorMap}, + alg::AbstractAlgorithm) + $right_f!(adjoint(t), reverse(adjoint.(F)), alg) + return F + end + @eval function $right_f!(t::AdjointTensorMap, + F::Tuple{AdjointTensorMap,AdjointTensorMap}, + alg::AbstractAlgorithm) + $left_f!(adjoint(t), reverse(adjoint.(F)), alg) + return F + end +end + +# 3-arg functions +for f! in (:svd_full!, :svd_compact!, :svd_trunc!) + @eval function initialize_output(::typeof($f!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return reverse(adjoint.(initialize_output($f!, adjoint(t), alg))) + end + _TS = f! === :svd_full! ? :AdjointTensorMap : DiagonalTensorMap + @eval function $f!(t::AdjointTensorMap, + F::Tuple{AdjointTensorMap,$_TS,AdjointTensorMap}, + alg::AbstractAlgorithm) + $f!(adjoint(t), reverse(adjoint.(F)), alg) + return F + end +end + +function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos()) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:leftorth!) + return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg'))) +end + +function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos()) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightorth!) + return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg'))) +end + +function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:leftnull!) + return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...)) +end + +function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightnull!) + return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...)) +end + +function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) + u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) + return adjoint(vt), adjoint(s), adjoint(u), err +end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 94124a04f..db1b1c3c7 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -40,6 +40,7 @@ include("implementations.jl") include("matrixalgebrakit.jl") include("truncation.jl") include("deprecations.jl") +include("adjoint.jl") TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A) @@ -54,37 +55,6 @@ end #------------------------------------------------------------------------------------------ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} -# AdjointTensorMap -# ---------------- -function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftorth!) - return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg'))) -end - -function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightorth!) - return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg'))) -end - -function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftnull!) - return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...)) -end - -function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightnull!) - return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...)) -end - -function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) - u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) - return adjoint(vt), adjoint(s), adjoint(u), err -end - # DiagonalTensorMap # ----------------- function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...) From 75082ef48cee9e6553490b3230ddfd14bb50bc5f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 31 Aug 2025 20:29:42 +0200 Subject: [PATCH 065/126] more adjoint specializations --- src/tensors/factorizations/adjoint.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 918248c29..f47ad9549 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -19,6 +19,13 @@ function right_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgo return N end +function MatrixAlgebraKit.is_left_isometry(t::AdjointTensorMap; kwargs...) + return is_right_isometry(adjoint(t); kwargs...) +end +function MatrixAlgebraKit.is_right_isometry(t::AdjointTensorMap; kwargs...) + return is_left_isometry(adjoint(t); kwargs...) +end + # 2-arg functions for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!), (:lq_full!, :lq_compact!, :right_polar!, :right_orth!)) From 9fb55940b91d28db27ffc3b7702f7bf2b0ff9f14 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 31 Aug 2025 20:32:16 +0200 Subject: [PATCH 066/126] remove previous adjoint specializations --- src/tensors/factorizations/adjoint.jl | 29 --------------------------- 1 file changed, 29 deletions(-) diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index f47ad9549..08154d15e 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -66,32 +66,3 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) return F end end - -function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftorth!) - return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg'))) -end - -function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightorth!) - return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg'))) -end - -function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftnull!) - return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...)) -end - -function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightnull!) - return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...)) -end - -function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) - u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) - return adjoint(vt), adjoint(s), adjoint(u), err -end From 782a40d4706cedc5bf0d37cb311220fa3c7b0f4c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 4 Sep 2025 15:36:20 -0400 Subject: [PATCH 067/126] Add `similar(::DiagonalTensorMap, [::Type{T}]) -> DiagonalTensorMap` --- src/tensors/diagonal.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 5a3840f1b..abdfeebb7 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -76,6 +76,11 @@ function DiagonalTensorMap(t::AbstractTensorMap{T,S,1,1}) where {T,S} return d end +Base.similar(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain) +function Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T<:Number} + return DiagonalTensorMap(similar(d.data, T), d.domain) +end + # TODO: more constructors needed? # Special case adjoint: From 510833d2e44a60171911dad2c37db5c6cdd2a8e2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 5 Sep 2025 11:10:36 -0400 Subject: [PATCH 068/126] clean up handling DiagonalTensorMap --- Project.toml | 2 +- src/tensors/factorizations/diagonal.jl | 78 ++++++++++++++++++++ src/tensors/factorizations/factorizations.jl | 56 +------------- 3 files changed, 81 insertions(+), 55 deletions(-) create mode 100644 src/tensors/factorizations/diagonal.jl diff --git a/Project.toml b/Project.toml index a9e1f7653..679b217b5 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.3" +MatrixAlgebraKit = "0.3.1" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl new file mode 100644 index 000000000..cd70a1d5f --- /dev/null +++ b/src/tensors/factorizations/diagonal.jl @@ -0,0 +1,78 @@ +# DiagonalTensorMap +# ----------------- +_repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data) + +for f in [ + :svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, + :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, + :eigh_trunc, :eigh_vals, :left_polar, :right_polar, + ] + @eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) +end + +for f! in (:qr_full!, :qr_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!) + @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) + return d, similar(d) + end +end +for f! in (:lq_full!, :lq_compact!) + @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) + return similar(d), d + end +end + +for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!) + @eval function $f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm) + check_input($f!, d, F, alg) + $f!(_repack_diagonal(d), _repack_diagonal.(F), alg) + return F + end +end + +for f! in (:qr_full!, :qr_compact!) + @eval function check_input(::typeof($f!), d::AbstractTensorMap, (Q, R)::_T_QR, ::DiagonalAlgorithm) + @assert d isa DiagonalTensorMap + @assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap + @check_scalar Q d + @check_scalar R d + @check_space(Q, space(d)) + @check_space(R, space(d)) + + return nothing + end +end + +for f! in (:lq_full!, :lq_compact!) + @eval function check_input(::typeof($f!), d::AbstractTensorMap, (L, Q)::_T_LQ, ::DiagonalAlgorithm) + @assert d isa DiagonalTensorMap + @assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap + @check_scalar Q d + @check_scalar L d + @check_space(Q, space(d)) + @check_space(L, space(d)) + + return nothing + end +end + +# f_vals +# ------ + +for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) + @eval function $f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm) + check_input($f!, d, V, alg) + $f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg) + return V + end + @eval function initialize_output(::typeof($f!), d::DiagonalTensorMap, alg::DiagonalAlgorithm) + data = initialize_output($f!, _repack_diagonal(d), alg) + return DiagonalTensorMap(data, d.domain) + end +end + +function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, + ::DiagonalAlgorithm) + @check_scalar D t + @check_space D space(t) + return nothing +end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index db1b1c3c7..2d6349e51 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -21,7 +21,7 @@ using TensorOperations: Index2Tuple using MatrixAlgebraKit using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy, NoTruncation, TruncationKeepAbove, TruncationKeepBelow, - TruncationIntersection, TruncationKeepFiltered + TruncationIntersection, TruncationKeepFiltered, DiagonalAlgorithm import MatrixAlgebraKit: default_algorithm, copy_input, check_input, initialize_output, qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, @@ -41,6 +41,7 @@ include("matrixalgebrakit.jl") include("truncation.jl") include("deprecations.jl") include("adjoint.jl") +include("diagonal.jl") TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A) @@ -55,59 +56,6 @@ end #------------------------------------------------------------------------------------------ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} -# DiagonalTensorMap -# ----------------- -function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...) - @assert alg isa Union{QR,QL} - return one(d), d # TODO: this is only correct for `alg = QR()` or `alg = QL()` -end -function rightorth!(d::DiagonalTensorMap; alg=LQ(), kwargs...) - @assert alg isa Union{LQ,RQ} - return d, one(d) # TODO: this is only correct for `alg = LQ()` or `alg = RQ()` -end -leftnull!(d::DiagonalTensorMap; kwargs...) = leftnull!(TensorMap(d); kwargs...) -rightnull!(d::DiagonalTensorMap; kwargs...) = rightnull!(TensorMap(d); kwargs...) - -function tsvd!(d::DiagonalTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) - return _tsvd!(d, alg, trunc, p) -end - -# helper function -function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD}) - InnerProductStyle(d) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) - I = sectortype(d) - dims = SectorDict{I,Int}() - generator = Base.Iterators.map(blocks(d)) do (c, b) - lb = length(b.diag) - U = zerovector!(similar(b.diag, lb, lb)) - V = zerovector!(similar(b.diag, lb, lb)) - p = sortperm(b.diag; by=abs, rev=true) - for (i, pi) in enumerate(p) - U[pi, i] = safesign(b.diag[pi]) - V[i, pi] = 1 - end - Σ = abs.(view(b.diag, p)) - dims[c] = lb - return c => (U, Σ, V) - end - SVDdata = SectorDict(generator) - return SVDdata, dims -end - -eig!(d::DiagonalTensorMap) = d, one(d) -eigh!(d::DiagonalTensorMap{<:Real}) = d, one(d) -eigh!(d::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(real(d.data), d.domain), one(d) - -function LinearAlgebra.svdvals(d::DiagonalTensorMap) - return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d)) -end -function LinearAlgebra.eigvals(d::DiagonalTensorMap) - return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d)) -end - -function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2) - return LinearAlgebra.cond(Diagonal(d.data), p) -end #------------------------------# # Singular value decomposition # #------------------------------# From a41d60a21ac5d32c1a294503a3129f0b44e50660 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 31 Aug 2025 20:26:25 +0200 Subject: [PATCH 069/126] clean up handling AdjointTensorMap --- src/tensors/factorizations/adjoint.jl | 29 ++++++++++ src/tensors/factorizations/factorizations.jl | 54 +++++++++++++++++++ .../factorizations/matrixalgebrakit.jl | 13 ++--- 3 files changed, 86 insertions(+), 10 deletions(-) diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 08154d15e..f47ad9549 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -66,3 +66,32 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) return F end end + +function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos()) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:leftorth!) + return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg'))) +end + +function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos()) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightorth!) + return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg'))) +end + +function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:leftnull!) + return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...)) +end + +function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...) + InnerProductStyle(t) === EuclideanInnerProduct() || + throw_invalid_innerproduct(:rightnull!) + return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...)) +end + +function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) + u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) + return adjoint(vt), adjoint(s), adjoint(u), err +end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 2d6349e51..cd6ea2613 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -56,6 +56,60 @@ end #------------------------------------------------------------------------------------------ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} +# DiagonalTensorMap +# ----------------- +function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...) + @assert alg isa Union{QR,QL} + return one(d), d # TODO: this is only correct for `alg = QR()` or `alg = QL()` +end +function rightorth!(d::DiagonalTensorMap; alg=LQ(), kwargs...) + @assert alg isa Union{LQ,RQ} + return d, one(d) # TODO: this is only correct for `alg = LQ()` or `alg = RQ()` +end +leftnull!(d::DiagonalTensorMap; kwargs...) = leftnull!(TensorMap(d); kwargs...) +rightnull!(d::DiagonalTensorMap; kwargs...) = rightnull!(TensorMap(d); kwargs...) + +function tsvd!(d::DiagonalTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) + return _tsvd!(d, alg, trunc, p) +end + +# helper function +function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD}) + InnerProductStyle(d) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) + I = sectortype(d) + dims = SectorDict{I,Int}() + generator = Base.Iterators.map(blocks(d)) do (c, b) + lb = length(b.diag) + U = zerovector!(similar(b.diag, lb, lb)) + V = zerovector!(similar(b.diag, lb, lb)) + p = sortperm(b.diag; by=abs, rev=true) + for (i, pi) in enumerate(p) + U[pi, i] = safesign(b.diag[pi]) + V[i, pi] = 1 + end + Σ = abs.(view(b.diag, p)) + dims[c] = lb + return c => (U, Σ, V) + end + SVDdata = SectorDict(generator) + return SVDdata, dims +end + +eig!(d::DiagonalTensorMap) = d, one(d) +eigh!(d::DiagonalTensorMap{<:Real}) = d, one(d) +eigh!(d::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(real(d.data), d.domain), one(d) + +function LinearAlgebra.svdvals(d::DiagonalTensorMap) + return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d)) +end +function LinearAlgebra.eigvals(d::DiagonalTensorMap) + return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d)) +end + +function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2) + return LinearAlgebra.cond(Diagonal(d.data), p) +end + #------------------------------# # Singular value decomposition # #------------------------------# diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 2cd40b01a..30ce0284c 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -538,15 +538,8 @@ function initialize_output(::typeof(right_null!), t::AbstractTensorMap) return N end -for f! in (:left_null_svd!, :right_null_svd!) - @eval function $f!(t::AbstractTensorMap, N, alg, ::Nothing=nothing) - foreachblock(t, N) do _, (b, n) - n′ = $f!(b, n, alg) - # deal with the case where the output is not the same as the input - n === n′ || copyto!(n, n′) - return nothing - end - - return N +for (f!, f_svd!) in zip((:left_null!, :right_null!), (:left_null_svd!, :right_null_svd!)) + @eval function $f_svd!(t::AbstractTensorMap, N, alg, ::Nothing=nothing) + return $f!(t, N, alg) end end From b4f884f707d9d195bef546f578f0847506183122 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 31 Aug 2025 20:32:16 +0200 Subject: [PATCH 070/126] remove previous adjoint specializations --- src/tensors/factorizations/adjoint.jl | 29 --------------------------- 1 file changed, 29 deletions(-) diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index f47ad9549..08154d15e 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -66,32 +66,3 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) return F end end - -function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftorth!) - return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg'))) -end - -function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos()) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightorth!) - return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg'))) -end - -function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftnull!) - return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...)) -end - -function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightnull!) - return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...)) -end - -function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) - u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg) - return adjoint(vt), adjoint(s), adjoint(u), err -end From fec8a9c5eeb3b429a4454bb9c4061f2b07bbf5ae Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 4 Sep 2025 10:01:54 -0400 Subject: [PATCH 071/126] move factorizations out into their own test file --- src/tensors/factorizations/adjoint.jl | 10 + src/tensors/factorizations/factorizations.jl | 7 +- src/tensors/factorizations/implementations.jl | 38 +++- src/tensors/factorizations/interface.jl | 2 +- .../factorizations/matrixalgebrakit.jl | 3 +- test/factorizations.jl | 202 ++++++++++++++++++ test/runtests.jl | 1 + test/tensors.jl | 165 -------------- 8 files changed, 248 insertions(+), 180 deletions(-) create mode 100644 test/factorizations.jl diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 08154d15e..5f01d971b 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -66,3 +66,13 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) return F end end +# avoid amgiguity +function initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap, + alg::TruncatedAlgorithm) + return initialize_output(svd_compact!, t, alg.alg) +end +# to fix ambiguity +function svd_trunc!(t::AdjointTensorMap, USVᴴ::Tuple{AdjointTensorMap,DiagonalTensorMap,AdjointTensorMap}, alg::TruncatedAlgorithm) + USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) + return truncate!(svd_trunc!, USVᴴ′, alg.trunc) +end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index cd6ea2613..522efbd9c 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -8,7 +8,8 @@ export tsvd, tsvd!, svdvals, svdvals! export leftorth, leftorth!, rightorth, rightorth! export leftnull, leftnull!, rightnull, rightnull! export copy_oftype, permutedcopy_oftype, one! -export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace +export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD +#export LAPACK_HouseholderQR, LAPACK_HouseholderLQ using ..TensorKit using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock, one! @@ -21,7 +22,9 @@ using TensorOperations: Index2Tuple using MatrixAlgebraKit using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy, NoTruncation, TruncationKeepAbove, TruncationKeepBelow, - TruncationIntersection, TruncationKeepFiltered, DiagonalAlgorithm + TruncationIntersection, TruncationKeepFiltered, PolarViaSVD, + LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR, + LAPACK_HouseholderLQ, DiagonalAlgorithm import MatrixAlgebraKit: default_algorithm, copy_input, check_input, initialize_output, qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index d00703549..49a63b9c1 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -3,6 +3,11 @@ _kindof(::Union{QR,QRpos}) = :qr _kindof(::Union{LQ,LQpos}) = :lq _kindof(::Polar) = :polar +_kindof(::LAPACK_HouseholderQR) = :qr +_kindof(::LAPACK_HouseholderLQ) = :lq +_kindof(::LAPACK_SVDAlgorithm) = :svd +_kindof(::PolarViaSVD) = :polar + leftorth!(t; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) function _leftorth!(t::AbstractTensorMap, alg::Nothing, ; kwargs...) @@ -19,14 +24,16 @@ function _leftorth!(t::AbstractTensorMap, alg::Union{QL,QLpos}; kwargs...) return Q, R end end -function _leftorth!(t, alg::OFA; kwargs...) +function _leftorth!(t, alg::Union{OFA,AbstractAlgorithm}; kwargs...) trunc = isempty(kwargs) ? nothing : (; kwargs...) - Base.depwarn(lazy"$alg is deprecated", :leftorth!) + alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :leftorth!) kind = _kindof(alg) if kind == :svd - alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : + alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) return left_orth!(t; kind, alg_svd, trunc) @@ -40,19 +47,22 @@ function _leftorth!(t, alg::OFA; kwargs...) end end # fallback to MatrixAlgebraKit version -_leftorth!(t, alg; kwargs...) = left_orth!(t; alg, kwargs...) +_leftorth!(t, alg; kwargs...) = left_orth!(t, alg; kwargs...) function leftnull!(t::AbstractTensorMap; - alg::Union{QR,QRpos,SVD,SDD,Nothing}=nothing, kwargs...) + alg::Union{LAPACK_HouseholderQR,LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,QR,QRpos,SVD,SDD,Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:leftnull!) trunc = isempty(kwargs) ? nothing : (; kwargs...) + alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :leftnull!) isnothing(alg) && return left_null!(t; trunc) kind = _kindof(alg) if kind == :svd - alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : + alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) return left_null!(t; kind, alg_svd, trunc) @@ -65,10 +75,12 @@ function leftnull!(t::AbstractTensorMap; end function rightorth!(t::AbstractTensorMap; - alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar,Nothing}=nothing, kwargs...) + alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,RQ,RQpos,SVD,SDD,Polar,Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:rightorth!) trunc = isempty(kwargs) ? nothing : (; kwargs...) + + alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :rightorth!) isnothing(alg) && return right_orth!(t; trunc) @@ -82,7 +94,9 @@ function rightorth!(t::AbstractTensorMap; kind = _kindof(alg) if kind == :svd - alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : + alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) return right_orth!(t; kind, alg_svd, trunc) @@ -97,16 +111,20 @@ function rightorth!(t::AbstractTensorMap; end function rightnull!(t::AbstractTensorMap; - alg::Union{LQ,LQpos,SVD,SDD,Nothing}=nothing, kwargs...) + alg::Union{LAPACK_HouseholderLQ, LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,SVD,SDD,Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:rightnull!) trunc = isempty(kwargs) ? nothing : (; kwargs...) + alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :rightnull!) + isnothing(alg) && return right_null!(t; trunc) kind = _kindof(alg) if kind == :svd - alg_svd = alg === SVD() ? LAPACK_QRIteration() : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : + alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) return right_null!(t; kind, alg_svd, trunc) diff --git a/src/tensors/factorizations/interface.jl b/src/tensors/factorizations/interface.jl index fc757a298..22c3c75d5 100644 --- a/src/tensors/factorizations/interface.jl +++ b/src/tensors/factorizations/interface.jl @@ -30,7 +30,7 @@ equivalent total dimension of the internal vector space is no larger than `χ`. The method `tsvd` also returns the truncation error `ϵ`, computed as the `p` norm of the singular values that were truncated. -THe keyword `alg` can be equal to `SVD()` or `SDD()`, corresponding to the underlying LAPACK +The keyword `alg` can be equal to `SVD()` or `SDD()`, corresponding to the underlying LAPACK algorithm that computes the decomposition (`_gesvd` or `_gesdd`). Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)` diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 30ce0284c..1aa6c630a 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -367,7 +367,6 @@ end # ------------------- const _T_WP = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} const _T_PWᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -using MatrixAlgebraKit: PolarViaSVD function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP, ::AbstractAlgorithm) codomain(t) ≿ domain(t) || @@ -540,6 +539,6 @@ end for (f!, f_svd!) in zip((:left_null!, :right_null!), (:left_null_svd!, :right_null_svd!)) @eval function $f_svd!(t::AbstractTensorMap, N, alg, ::Nothing=nothing) - return $f!(t, N, alg) + return $f!(t, N; alg_svd=alg) end end diff --git a/test/factorizations.jl b/test/factorizations.jl new file mode 100644 index 000000000..c76511bd8 --- /dev/null +++ b/test/factorizations.jl @@ -0,0 +1,202 @@ +spacelist = try + if ENV["CI"] == "true" + println("Detected running on CI") + if Sys.iswindows() + (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂) + elseif Sys.isapple() + (Vtr, Vℤ₃, VfU₁, VfSU₂) + else + (Vtr, VU₁, VCU₁, VSU₂, VfSU₂) + end + else + (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂) + end +catch + (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂) +end + +for V in spacelist + I = sectortype(first(V)) + Istr = TensorKit.type_repr(I) + println("---------------------------------------") + println("Tensors with symmetry: $Istr") + println("---------------------------------------") + @timedtestset "Tensors with symmetry: $Istr" verbose = true begin + V1, V2, V3, V4, V5 = V + @timedtestset "Factorization" begin + W = V1 ⊗ V2 + @testset for T in (Float32, ComplexF64) + # Test both a normal tensor and an adjoint one. + ts = (rand(T, W, W'), rand(T, W, W')') + @testset for t in ts + # test squares and rectangles here + @testset "leftorth with $alg" for alg in + (TensorKit.LAPACK_HouseholderQR(), + TensorKit.LAPACK_HouseholderQR(positive=true), + #TensorKit.QL(), + #TensorKit.QLpos(), + TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), + TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + Q, R = @constinferred leftorth(t; alg=alg) + @test isisometry(Q) + tQR = Q * R + @test tQR ≈ t + end + @testset "leftnull with $alg" for alg in + (TensorKit.LAPACK_HouseholderQR(), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + N = @constinferred leftnull(t; alg=alg) + @test isisometry(N) + @test norm(N' * t) < 100 * eps(norm(t)) + end + @testset "rightorth with $alg" for alg in + (#TensorKit.RQ(), TensorKit.RQpos(), + TensorKit.LAPACK_HouseholderLQ(), + TensorKit.LAPACK_HouseholderLQ(positive=true), + TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), + TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + L, Q = @constinferred rightorth(t; alg=alg) + @test isisometry(Q; side=:right) + @test L * Q ≈ t + end + @testset "rightnull with $alg" for alg in + (TensorKit.LAPACK_HouseholderLQ(), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + M = @constinferred rightnull(t; alg=alg) + @test isisometry(M; side=:right) + @test norm(t * M') < 100 * eps(norm(t)) + end + @testset "tsvd with $alg" for alg in (TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + U, S, V = @constinferred tsvd(t; alg=alg) + @test isisometry(U) + @test isisometry(V; side=:right) + @test U * S * V ≈ t + + s = LinearAlgebra.svdvals(t) + s′ = LinearAlgebra.diag(S) + for (c, b) in s + @test b ≈ s′[c] + end + s = LinearAlgebra.svdvals(t') + s′ = LinearAlgebra.diag(S') + for (c, b) in s + @test b ≈ s′[c] + end + end + @testset "cond and rank" begin + d1 = dim(codomain(t)) + d2 = dim(domain(t)) + @test rank(t) == min(d1, d2) + M = leftnull(t) + @test rank(M) == max(d1, d2) - min(d1, d2) + t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) + @test cond(t3) ≈ one(real(T)) + @test rank(t3) == dim(V1 ⊗ V2) + t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) + t4 = (t4 + t4') / 2 + vals = LinearAlgebra.eigvals(t4) + λmax = maximum(s -> maximum(abs, s), values(vals)) + λmin = minimum(s -> minimum(abs, s), values(vals)) + @test cond(t4) ≈ λmax / λmin + vals = LinearAlgebra.eigvals(t4') + λmax = maximum(s -> maximum(abs, s), values(vals)) + λmin = minimum(s -> minimum(abs, s), values(vals)) + @test cond(t4') ≈ λmax / λmin + end + end + @testset "empty tensor" begin + t = randn(T, V1 ⊗ V2, zero(V1)) + @testset "leftorth with $alg" for alg in + (TensorKit.LAPACK_HouseholderQR(), + TensorKit.LAPACK_HouseholderQR(positive=true), + #TensorKit.QL(), TensorKit.QLpos(), + TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), + TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + Q, R = @constinferred leftorth(t; alg=alg) + @test Q == t + @test dim(Q) == dim(R) == 0 + end + @testset "leftnull with $alg" for alg in + (TensorKit.LAPACK_HouseholderQR(), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + N = @constinferred leftnull(t; alg=alg) + @test isunitary(N) + end + @testset "rightorth with $alg" for alg in + (#TensorKit.RQ(), TensorKit.RQpos(), + TensorKit.LAPACK_HouseholderLQ(), + TensorKit.LAPACK_HouseholderLQ(positive=true), + TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), + TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + L, Q = @constinferred rightorth(copy(t'); alg=alg) + @test Q == t' + @test dim(Q) == dim(L) == 0 + end + @testset "rightnull with $alg" for alg in + (TensorKit.LAPACK_HouseholderLQ(), + TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + M = @constinferred rightnull(copy(t'); alg=alg) + @test isunitary(M) + end + @testset "tsvd with $alg" for alg in (TensorKit.LAPACK_QRIteration(), + TensorKit.LAPACK_DivideAndConquer()) + U, S, V = @constinferred tsvd(t; alg=alg) + @test U == t + @test dim(U) == dim(S) == dim(V) + end + @testset "cond and rank" begin + @test rank(t) == 0 + W2 = zero(V1) * zero(V2) + t2 = rand(W2, W2) + @test rank(t2) == 0 + @test cond(t2) == 0.0 + end + end + @testset "eig and isposdef" begin + t = rand(T, V1, V1) + D, V = eigen(t) + @test t * V ≈ V * D + + d = LinearAlgebra.eigvals(t; sortby=nothing) + d′ = LinearAlgebra.diag(D) + for (c, b) in d + @test b ≈ d′[c] + end + + # Somehow moving these test before the previous one gives rise to errors + # with T=Float32 on x86 platforms. Is this an OpenBLAS issue? + VdV = V' * V + VdV = (VdV + VdV') / 2 + @test isposdef(VdV) + + @test !isposdef(t) # unlikely for non-hermitian map + t2 = (t + t') + D, V = eigen(t2) + @test isisometry(V) + D̃, Ṽ = @constinferred eigh(t2) + @test D ≈ D̃ + @test V ≈ Ṽ + λ = minimum(minimum(real(LinearAlgebra.diag(b))) + for (c, b) in blocks(D)) + @test cond(Ṽ) ≈ one(real(T)) + @test isposdef(t2) == isposdef(λ) + @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) + @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) + end + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 9fafa1d9f..03d7c183f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -121,6 +121,7 @@ if !is_buildkite include("fusiontrees.jl") include("spaces.jl") include("tensors.jl") + include("factorizations.jl") include("diagonal.jl") include("planar.jl") # TODO: remove once we know AD is slow on macOS CI diff --git a/test/tensors.jl b/test/tensors.jl index b827d9f12..796becded 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -1,9 +1,3 @@ -for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VSU₂U₁)#, VSU₃) - V1, V2, V3, V4, V5 = V - @assert V3 * V4 * V2 ≿ V1' * V5' # necessary for leftorth tests - @assert V3 * V4 ≾ V1' * V2' * V5' # necessary for rightorth tests -end - spacelist = try if ENV["CI"] == "true" println("Detected running on CI") @@ -438,165 +432,6 @@ for V in spacelist @test LinearAlgebra.isdiag(D) @test LinearAlgebra.diag(D) == d end - @timedtestset "Factorization" begin - W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 - for T in (Float32, ComplexF64) - # Test both a normal tensor and an adjoint one. - ts = (rand(T, W), rand(T, W)') - for t in ts - @testset "leftorth with $alg" for alg in - (TensorKit.QR(), TensorKit.QRpos(), - TensorKit.QL(), TensorKit.QLpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - Q, R = @constinferred leftorth(t, ((3, 4, 2), (1, 5)); alg=alg) - @test isisometry(Q) - @test Q * R ≈ permute(t, ((3, 4, 2), (1, 5))) - end - @testset "leftnull with $alg" for alg in - (TensorKit.QR(), TensorKit.SVD(), - TensorKit.SDD()) - N = @constinferred leftnull(t, ((3, 4, 2), (1, 5)); alg=alg) - @test isisometry(N) - @test norm(N' * permute(t, ((3, 4, 2), (1, 5)))) < - 100 * eps(norm(t)) - end - @testset "rightorth with $alg" for alg in - (TensorKit.RQ(), TensorKit.RQpos(), - TensorKit.LQ(), TensorKit.LQpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - L, Q = @constinferred rightorth(t, ((3, 4), (2, 1, 5)); alg=alg) - @test isisometry(Q; side=:right) - @test L * Q ≈ permute(t, ((3, 4), (2, 1, 5))) - end - @testset "rightnull with $alg" for alg in - (TensorKit.LQ(), TensorKit.SVD(), - TensorKit.SDD()) - M = @constinferred rightnull(t, ((3, 4), (2, 1, 5)); alg=alg) - @test isisometry(M; side=:right) - @test norm(permute(t, ((3, 4), (2, 1, 5))) * M') < - 100 * eps(norm(t)) - end - @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) - U, S, V = @constinferred tsvd(t, ((3, 4, 2), (1, 5)); alg=alg) - @test isisometry(U) - @test isisometry(V; side=:right) - t2 = permute(t, ((3, 4, 2), (1, 5))) - @test U * S * V ≈ t2 - - s = LinearAlgebra.svdvals(t2) - s′ = LinearAlgebra.diag(S) - for (c, b) in s - @test b ≈ s′[c] - end - s = LinearAlgebra.svdvals(t2') - s′ = LinearAlgebra.diag(S') - for (c, b) in s - @test b ≈ s′[c] - end - end - @testset "cond and rank" begin - t2 = permute(t, ((3, 4, 2), (1, 5))) - d1 = dim(codomain(t2)) - d2 = dim(domain(t2)) - @test rank(t2) == min(d1, d2) - M = leftnull(t2) - @test rank(M) == max(d1, d2) - min(d1, d2) - t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) - @test cond(t3) ≈ one(real(T)) - @test rank(t3) == dim(V1 ⊗ V2) - t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) - t4 = (t4 + t4') / 2 - vals = LinearAlgebra.eigvals(t4) - λmax = maximum(s -> maximum(abs, s), values(vals)) - λmin = minimum(s -> minimum(abs, s), values(vals)) - @test cond(t4) ≈ λmax / λmin - vals = LinearAlgebra.eigvals(t4') - λmax = maximum(s -> maximum(abs, s), values(vals)) - λmin = minimum(s -> minimum(abs, s), values(vals)) - @test cond(t4') ≈ λmax / λmin - end - end - @testset "empty tensor" begin - t = randn(T, V1 ⊗ V2, zero(V1)) - @testset "leftorth with $alg" for alg in - (TensorKit.QR(), TensorKit.QRpos(), - TensorKit.QL(), TensorKit.QLpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - Q, R = @constinferred leftorth(t; alg=alg) - @test Q == t - @test dim(Q) == dim(R) == 0 - end - @testset "leftnull with $alg" for alg in - (TensorKit.QR(), TensorKit.SVD(), - TensorKit.SDD()) - N = @constinferred leftnull(t; alg=alg) - @test isunitary(N) - end - @testset "rightorth with $alg" for alg in - (TensorKit.RQ(), TensorKit.RQpos(), - TensorKit.LQ(), TensorKit.LQpos(), - TensorKit.Polar(), TensorKit.SVD(), - TensorKit.SDD()) - L, Q = @constinferred rightorth(copy(t'); alg=alg) - @test Q == t' - @test dim(Q) == dim(L) == 0 - end - @testset "rightnull with $alg" for alg in - (TensorKit.LQ(), TensorKit.SVD(), - TensorKit.SDD()) - M = @constinferred rightnull(copy(t'); alg=alg) - @test isunitary(M) - end - @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) - U, S, V = @constinferred tsvd(t; alg=alg) - @test U == t - @test dim(U) == dim(S) == dim(V) - end - @testset "cond and rank" begin - @test rank(t) == 0 - W2 = zero(V1) * zero(V2) - t2 = rand(W2, W2) - @test rank(t2) == 0 - @test cond(t2) == 0.0 - end - end - t = rand(T, V1 ⊗ V1' ⊗ V2 ⊗ V2') - @testset "eig and isposdef" begin - D, V = eigen(t, ((1, 3), (2, 4))) - t2 = permute(t, ((1, 3), (2, 4))) - @test t2 * V ≈ V * D - - d = LinearAlgebra.eigvals(t2; sortby=nothing) - d′ = LinearAlgebra.diag(D) - for (c, b) in d - @test b ≈ d′[c] - end - - # Somehow moving these test before the previous one gives rise to errors - # with T=Float32 on x86 platforms. Is this an OpenBLAS issue? - VdV = V' * V - VdV = (VdV + VdV') / 2 - @test isposdef(VdV) - - @test !isposdef(t2) # unlikely for non-hermitian map - t2 = (t2 + t2') - D, V = eigen(t2) - @test isisometry(V) - D̃, Ṽ = @constinferred eigh(t2) - @test D ≈ D̃ - @test V ≈ Ṽ - λ = minimum(minimum(real(LinearAlgebra.diag(b))) - for (c, b) in blocks(D)) - @test cond(Ṽ) ≈ one(real(T)) - @test isposdef(t2) == isposdef(λ) - @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) - @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) - end - end - end @timedtestset "Tensor truncation" begin for T in (Float32, ComplexF64) for p in (1, 2, 3, Inf) From e82f8b61483c70f1140d1e433ef6d1254ab3b1b1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 5 Sep 2025 11:13:22 -0400 Subject: [PATCH 072/126] fixup! clean up handling DiagonalTensorMap --- src/tensors/factorizations/diagonal.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index cd70a1d5f..543055406 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -2,11 +2,11 @@ # ----------------- _repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data) -for f in [ +for f in ( :svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, :left_polar, :right_polar, - ] + ) @eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) end From 6b80f6208b2a1803860b735ed58b5a95471813b5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 5 Sep 2025 11:48:41 -0400 Subject: [PATCH 073/126] little bit of cleanup --- src/tensors/factorizations/factorizations.jl | 28 +++++--------------- src/tensors/linalg.jl | 1 - 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 522efbd9c..bb0e0679b 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -9,7 +9,6 @@ export leftorth, leftorth!, rightorth, rightorth! export leftnull, leftnull!, rightnull, rightnull! export copy_oftype, permutedcopy_oftype, one! export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD -#export LAPACK_HouseholderQR, LAPACK_HouseholderLQ using ..TensorKit using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock, one! @@ -114,30 +113,17 @@ function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2) end #------------------------------# -# Singular value decomposition # +# LinearAlgebra overloads #------------------------------# -function LinearAlgebra.svdvals!(t::TensorMap{<:RealOrComplexFloat}) - return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t)) -end -LinearAlgebra.svdvals!(t::AdjointTensorMap) = svdvals!(adjoint(t)) - -#--------------------------# -# Eigenvalue decomposition # -#--------------------------# - -function LinearAlgebra.eigvals!(t::TensorMap{<:RealOrComplexFloat}; kwargs...) - return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...)) - for (c, b) in blocks(t)) -end -function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwargs...) - return SectorDict(c => conj!(complex(LinearAlgebra.eigvals!(b; kwargs...))) - for (c, b) in blocks(t)) -end +LinearAlgebra.svdvals(t::AbstractTensorMap) = diagview(svd_vals(t)) +LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t)) +LinearAlgebra.eigvals(t::AbstractTensorMap) = diagview(eigvals(t)) +LinearAlgebra.eigvals!(t::AbstractTensorMap) = diagview(eigvals!(t)) #--------------------------------------------------# # Checks for hermiticity and positive definiteness # #--------------------------------------------------# -function LinearAlgebra.ishermitian(t::TensorMap) +function LinearAlgebra.ishermitian(t::AbstractTensorMap) domain(t) == codomain(t) || return false InnerProductStyle(t) === EuclideanInnerProduct() || return false # hermiticity only defined for euclidean for (c, b) in blocks(t) @@ -146,7 +132,7 @@ function LinearAlgebra.ishermitian(t::TensorMap) return true end -function LinearAlgebra.isposdef!(t::TensorMap) +function LinearAlgebra.isposdef!(t::AbstractTensorMap) domain(t) == codomain(t) || throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index b52e8b4bc..5afa3e031 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -179,7 +179,6 @@ end # Diagonal tensors # ---------------- -# TODO: consider adding a specialised DiagonalTensorMap type function LinearAlgebra.diag(t::AbstractTensorMap) return SectorDict(c => LinearAlgebra.diag(b) for (c, b) in blocks(t)) end From 1ad76d4d457496ae0cf6c4174c33c230aa1ae555 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 5 Sep 2025 11:53:47 -0400 Subject: [PATCH 074/126] foreachblock --- src/tensors/diagonal.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index abdfeebb7..27b6bf372 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -282,16 +282,16 @@ end function LinearAlgebra.lmul!(D::DiagonalTensorMap, t::AbstractTensorMap) domain(D) == codomain(t) || throw(SpaceMismatch()) - for (c, b) in blocks(t) - lmul!(block(D, c), b) + foreachblock(D, t) do c, bs + lmul!(bs...) end return t end function LinearAlgebra.rmul!(t::AbstractTensorMap, D::DiagonalTensorMap) codomain(D) == domain(t) || throw(SpaceMismatch()) - for (c, b) in blocks(t) - rmul!(b, block(D, c)) + foreachblock(t, D) do c, bs + rmul!(bs...) end return t end From 9a84e88058beffd97b432e08698a26a78ceff76b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 5 Sep 2025 11:55:20 -0400 Subject: [PATCH 075/126] remove unused functions --- src/auxiliary/linalg.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index 8a20eb8ac..e8c595172 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -46,5 +46,3 @@ Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg const OFA = OrthogonalFactorizationAlgorithm const SVDAlg = Union{SVD,SDD} -safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s)) -safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) From 5a550be0c25a492c37dbddd494f3bad888c14752 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 5 Sep 2025 11:58:19 -0400 Subject: [PATCH 076/126] formatter --- src/auxiliary/linalg.jl | 1 - src/tensors/diagonal.jl | 4 +-- src/tensors/factorizations/adjoint.jl | 4 ++- src/tensors/factorizations/diagonal.jl | 27 ++++++++++------- src/tensors/factorizations/implementations.jl | 30 +++++++++++-------- test/factorizations.jl | 18 ++++++----- 6 files changed, 49 insertions(+), 35 deletions(-) diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index e8c595172..4a3bf9b15 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -45,4 +45,3 @@ Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg const OFA = OrthogonalFactorizationAlgorithm const SVDAlg = Union{SVD,SDD} - diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 27b6bf372..268f7f24f 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -283,7 +283,7 @@ end function LinearAlgebra.lmul!(D::DiagonalTensorMap, t::AbstractTensorMap) domain(D) == codomain(t) || throw(SpaceMismatch()) foreachblock(D, t) do c, bs - lmul!(bs...) + return lmul!(bs...) end return t end @@ -291,7 +291,7 @@ end function LinearAlgebra.rmul!(t::AbstractTensorMap, D::DiagonalTensorMap) codomain(D) == domain(t) || throw(SpaceMismatch()) foreachblock(t, D) do c, bs - rmul!(bs...) + return rmul!(bs...) end return t end diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 5f01d971b..49f36777a 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -72,7 +72,9 @@ function initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap, return initialize_output(svd_compact!, t, alg.alg) end # to fix ambiguity -function svd_trunc!(t::AdjointTensorMap, USVᴴ::Tuple{AdjointTensorMap,DiagonalTensorMap,AdjointTensorMap}, alg::TruncatedAlgorithm) +function svd_trunc!(t::AdjointTensorMap, + USVᴴ::Tuple{AdjointTensorMap,DiagonalTensorMap,AdjointTensorMap}, + alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) return truncate!(svd_trunc!, USVᴴ′, alg.trunc) end diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index 543055406..25347f482 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -2,26 +2,28 @@ # ----------------- _repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data) -for f in ( - :svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, - :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, - :eigh_trunc, :eigh_vals, :left_polar, :right_polar, - ) +for f in (:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, + :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, + :eigh_trunc, :eigh_vals, :left_polar, :right_polar) @eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) end for f! in (:qr_full!, :qr_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!) - @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) + @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) return d, similar(d) end end for f! in (:lq_full!, :lq_compact!) - @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) + @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) return similar(d), d end end -for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!) +for f! in + (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, + :eigh_trunc!) @eval function $f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm) check_input($f!, d, F, alg) $f!(_repack_diagonal(d), _repack_diagonal.(F), alg) @@ -30,7 +32,8 @@ for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_tr end for f! in (:qr_full!, :qr_compact!) - @eval function check_input(::typeof($f!), d::AbstractTensorMap, (Q, R)::_T_QR, ::DiagonalAlgorithm) + @eval function check_input(::typeof($f!), d::AbstractTensorMap, (Q, R)::_T_QR, + ::DiagonalAlgorithm) @assert d isa DiagonalTensorMap @assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap @check_scalar Q d @@ -43,7 +46,8 @@ for f! in (:qr_full!, :qr_compact!) end for f! in (:lq_full!, :lq_compact!) - @eval function check_input(::typeof($f!), d::AbstractTensorMap, (L, Q)::_T_LQ, ::DiagonalAlgorithm) + @eval function check_input(::typeof($f!), d::AbstractTensorMap, (L, Q)::_T_LQ, + ::DiagonalAlgorithm) @assert d isa DiagonalTensorMap @assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap @check_scalar Q d @@ -64,7 +68,8 @@ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) $f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg) return V end - @eval function initialize_output(::typeof($f!), d::DiagonalTensorMap, alg::DiagonalAlgorithm) + @eval function initialize_output(::typeof($f!), d::DiagonalTensorMap, + alg::DiagonalAlgorithm) data = initialize_output($f!, _repack_diagonal(d), alg) return DiagonalTensorMap(data, d.domain) end diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index 49a63b9c1..c898d4f33 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -31,8 +31,8 @@ function _leftorth!(t, alg::Union{OFA,AbstractAlgorithm}; kwargs...) kind = _kindof(alg) if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) @@ -50,7 +50,9 @@ end _leftorth!(t, alg; kwargs...) = left_orth!(t, alg; kwargs...) function leftnull!(t::AbstractTensorMap; - alg::Union{LAPACK_HouseholderQR,LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,QR,QRpos,SVD,SDD,Nothing}=nothing, kwargs...) + alg::Union{LAPACK_HouseholderQR,LAPACK_QRIteration, + LAPACK_DivideAndConquer,PolarViaSVD,QR,QRpos,SVD,SDD,Nothing}=nothing, + kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:leftnull!) trunc = isempty(kwargs) ? nothing : (; kwargs...) @@ -60,8 +62,8 @@ function leftnull!(t::AbstractTensorMap; kind = _kindof(alg) if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) @@ -75,11 +77,13 @@ function leftnull!(t::AbstractTensorMap; end function rightorth!(t::AbstractTensorMap; - alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,RQ,RQpos,SVD,SDD,Polar,Nothing}=nothing, kwargs...) + alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, + LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,RQ,RQpos,SVD, + SDD,Polar,Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:rightorth!) trunc = isempty(kwargs) ? nothing : (; kwargs...) - + alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :rightorth!) isnothing(alg) && return right_orth!(t; trunc) @@ -94,8 +98,8 @@ function rightorth!(t::AbstractTensorMap; kind = _kindof(alg) if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) @@ -111,7 +115,9 @@ function rightorth!(t::AbstractTensorMap; end function rightnull!(t::AbstractTensorMap; - alg::Union{LAPACK_HouseholderLQ, LAPACK_QRIteration, LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,SVD,SDD,Nothing}=nothing, kwargs...) + alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, + LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,SVD,SDD, + Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:rightnull!) trunc = isempty(kwargs) ? nothing : (; kwargs...) @@ -122,8 +128,8 @@ function rightnull!(t::AbstractTensorMap; kind = _kindof(alg) if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : + alg_svd = alg === LAPACK_QRIteration() ? alg : + alg === LAPACK_DivideAndConquer() ? alg : alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) diff --git a/test/factorizations.jl b/test/factorizations.jl index c76511bd8..e049dc825 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -32,7 +32,8 @@ for V in spacelist # test squares and rectangles here @testset "leftorth with $alg" for alg in (TensorKit.LAPACK_HouseholderQR(), - TensorKit.LAPACK_HouseholderQR(positive=true), + TensorKit.LAPACK_HouseholderQR(; + positive=true), #TensorKit.QL(), #TensorKit.QLpos(), TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), @@ -53,9 +54,9 @@ for V in spacelist @test norm(N' * t) < 100 * eps(norm(t)) end @testset "rightorth with $alg" for alg in - (#TensorKit.RQ(), TensorKit.RQpos(), - TensorKit.LAPACK_HouseholderLQ(), - TensorKit.LAPACK_HouseholderLQ(positive=true), + (TensorKit.LAPACK_HouseholderLQ(), + TensorKit.LAPACK_HouseholderLQ(; + positive=true), TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), TensorKit.LAPACK_QRIteration(), @@ -115,7 +116,8 @@ for V in spacelist t = randn(T, V1 ⊗ V2, zero(V1)) @testset "leftorth with $alg" for alg in (TensorKit.LAPACK_HouseholderQR(), - TensorKit.LAPACK_HouseholderQR(positive=true), + TensorKit.LAPACK_HouseholderQR(; + positive=true), #TensorKit.QL(), TensorKit.QLpos(), TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), @@ -133,9 +135,9 @@ for V in spacelist @test isunitary(N) end @testset "rightorth with $alg" for alg in - (#TensorKit.RQ(), TensorKit.RQpos(), - TensorKit.LAPACK_HouseholderLQ(), - TensorKit.LAPACK_HouseholderLQ(positive=true), + (TensorKit.LAPACK_HouseholderLQ(), + TensorKit.LAPACK_HouseholderLQ(; + positive=true), TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), TensorKit.LAPACK_QRIteration(), From 55f8545334e91bf96afdb9546cd1010b99f82b8b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 5 Sep 2025 12:00:52 -0400 Subject: [PATCH 077/126] more cleanup --- src/TensorKit.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 038bdce0e..bf359e8de 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -222,8 +222,6 @@ include("tensors/braidingtensor.jl") include("tensors/factorizations/factorizations.jl") using .Factorizations -# include("tensors/factorizations/matrixalgebrakit.jl") -# include("tensors/truncation.jl") # # Planar macros and related functionality # #----------------------------------------- From 2911c0c83a656f6105e0ea82f34e0023b7cc3dd3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 8 Sep 2025 14:32:51 -0400 Subject: [PATCH 078/126] Updates for Diagonal and rectangular array tests --- src/tensors/factorizations/diagonal.jl | 19 ++++++++++++- src/tensors/factorizations/factorizations.jl | 6 ++-- .../factorizations/matrixalgebrakit.jl | 28 +++++++++++++++++-- test/factorizations.jl | 26 ++++++++--------- 4 files changed, 60 insertions(+), 19 deletions(-) diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index 25347f482..67084d589 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -8,17 +8,34 @@ for f in (:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, @eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) end -for f! in (:qr_full!, :qr_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!) +for f! in (:eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!) @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) return d, similar(d) end end + +for f! in (:qr_full!, :qr_compact!) + @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) + return d, similar(d) + end + # to avoid ambiguities + @eval function initialize_output(::typeof($f!), d::AdjointTensorMap, + ::DiagonalAlgorithm) + return d, similar(d) + end +end for f! in (:lq_full!, :lq_compact!) @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) return similar(d), d end + # to avoid ambiguities + @eval function initialize_output(::typeof($f!), d::AdjointTensorMap, + ::DiagonalAlgorithm) + return similar(d), d + end end for f! in diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index bb0e0679b..e84cef284 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -115,10 +115,10 @@ end #------------------------------# # LinearAlgebra overloads #------------------------------# -LinearAlgebra.svdvals(t::AbstractTensorMap) = diagview(svd_vals(t)) +#LinearAlgebra.svdvals(t::AbstractTensorMap) = diagview(svd_vals(t)) LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t)) -LinearAlgebra.eigvals(t::AbstractTensorMap) = diagview(eigvals(t)) -LinearAlgebra.eigvals!(t::AbstractTensorMap) = diagview(eigvals!(t)) +#LinearAlgebra.eigvals(t::AbstractTensorMap) = diagview(eig_vals(t)) +LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = diagview(eig_vals!(t)) #--------------------------------------------------# # Checks for hermiticity and positive definiteness # diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 1aa6c630a..d6eb22a93 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -45,7 +45,7 @@ for f! in (:qr_compact!, :qr_full!, end # Handle these separately because single output instead of tuple -for f! in (:qr_null!, :lq_null!, :svd_vals!, :eig_vals!, :eigh_vals!) +for f! in (:qr_null!, :lq_null!) @eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) check_input($f!, t, N, alg) @@ -60,6 +60,22 @@ for f! in (:qr_null!, :lq_null!, :svd_vals!, :eig_vals!, :eigh_vals!) end end +# Handle these separately because single output instead of tuple +for f! in (:svd_vals!, :eig_vals!, :eigh_vals!) + @eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) + check_input($f!, t, N, alg) + + foreachblock(t, N) do _, (b, n) + n′ = $f!(b, n.diag, alg) + # deal with the case where the output is not the same as the input + n.diag === n′ || copyto!(n, diagview(n′)) + return nothing + end + + return N + end +end + # Singular value decomposition # ---------------------------- const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} @@ -98,11 +114,19 @@ end function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict, ::AbstractAlgorithm) @check_scalar S t real - V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) @check_space(S, V_cod ← V_dom) return nothing end +function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, + ::AbstractAlgorithm) + @check_scalar D t real + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) + @check_space(D, V_cod ← V_dom) + return nothing +end + function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm) V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) diff --git a/test/factorizations.jl b/test/factorizations.jl index e049dc825..ab2284837 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -27,9 +27,8 @@ for V in spacelist W = V1 ⊗ V2 @testset for T in (Float32, ComplexF64) # Test both a normal tensor and an adjoint one. - ts = (rand(T, W, W'), rand(T, W, W')') + ts = (rand(T, W, W'), rand(T, W, W')', rand(T, V1, W'), rand(T, V1, W')') @testset for t in ts - # test squares and rectangles here @testset "leftorth with $alg" for alg in (TensorKit.LAPACK_HouseholderQR(), TensorKit.LAPACK_HouseholderQR(; @@ -40,10 +39,10 @@ for V in spacelist TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), TensorKit.LAPACK_QRIteration(), TensorKit.LAPACK_DivideAndConquer()) + (codomain(t) ≾ domain(t)) && alg isa TensorKit.PolarViaSVD && continue Q, R = @constinferred leftorth(t; alg=alg) @test isisometry(Q) - tQR = Q * R - @test tQR ≈ t + @test Q * R ≈ t end @testset "leftnull with $alg" for alg in (TensorKit.LAPACK_HouseholderQR(), @@ -61,6 +60,7 @@ for V in spacelist TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), TensorKit.LAPACK_QRIteration(), TensorKit.LAPACK_DivideAndConquer()) + (domain(t) ≾ codomain(t)) && alg isa TensorKit.PolarViaSVD && continue L, Q = @constinferred rightorth(t; alg=alg) @test isisometry(Q; side=:right) @test L * Q ≈ t @@ -80,28 +80,28 @@ for V in spacelist @test isisometry(V; side=:right) @test U * S * V ≈ t - s = LinearAlgebra.svdvals(t) + s = LinearAlgebra.svdvals(t) s′ = LinearAlgebra.diag(S) for (c, b) in s @test b ≈ s′[c] end - s = LinearAlgebra.svdvals(t') + s = LinearAlgebra.svdvals(t') s′ = LinearAlgebra.diag(S') for (c, b) in s @test b ≈ s′[c] end end @testset "cond and rank" begin - d1 = dim(codomain(t)) - d2 = dim(domain(t)) + d1 = dim(codomain(t)) + d2 = dim(domain(t)) @test rank(t) == min(d1, d2) - M = leftnull(t) - @test rank(M) == max(d1, d2) - min(d1, d2) - t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) + M = leftnull(t) + @test rank(M) + rank(t) == d1 + t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) @test cond(t3) ≈ one(real(T)) @test rank(t3) == dim(V1 ⊗ V2) - t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) - t4 = (t4 + t4') / 2 + t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) + t4 = (t4 + t4') / 2 vals = LinearAlgebra.eigvals(t4) λmax = maximum(s -> maximum(abs, s), values(vals)) λmin = minimum(s -> minimum(abs, s), values(vals)) From 847b67bd0b0217151962d9063ce0c18a3da96054 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 8 Sep 2025 14:45:08 -0400 Subject: [PATCH 079/126] Format fix --- test/factorizations.jl | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/test/factorizations.jl b/test/factorizations.jl index ab2284837..b368b6f95 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -39,7 +39,8 @@ for V in spacelist TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), TensorKit.LAPACK_QRIteration(), TensorKit.LAPACK_DivideAndConquer()) - (codomain(t) ≾ domain(t)) && alg isa TensorKit.PolarViaSVD && continue + (codomain(t) ≾ domain(t)) && alg isa TensorKit.PolarViaSVD && + continue Q, R = @constinferred leftorth(t; alg=alg) @test isisometry(Q) @test Q * R ≈ t @@ -60,7 +61,8 @@ for V in spacelist TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), TensorKit.LAPACK_QRIteration(), TensorKit.LAPACK_DivideAndConquer()) - (domain(t) ≾ codomain(t)) && alg isa TensorKit.PolarViaSVD && continue + (domain(t) ≾ codomain(t)) && alg isa TensorKit.PolarViaSVD && + continue L, Q = @constinferred rightorth(t; alg=alg) @test isisometry(Q; side=:right) @test L * Q ≈ t @@ -80,28 +82,28 @@ for V in spacelist @test isisometry(V; side=:right) @test U * S * V ≈ t - s = LinearAlgebra.svdvals(t) + s = LinearAlgebra.svdvals(t) s′ = LinearAlgebra.diag(S) for (c, b) in s @test b ≈ s′[c] end - s = LinearAlgebra.svdvals(t') + s = LinearAlgebra.svdvals(t') s′ = LinearAlgebra.diag(S') for (c, b) in s @test b ≈ s′[c] end end @testset "cond and rank" begin - d1 = dim(codomain(t)) - d2 = dim(domain(t)) + d1 = dim(codomain(t)) + d2 = dim(domain(t)) @test rank(t) == min(d1, d2) - M = leftnull(t) + M = leftnull(t) @test rank(M) + rank(t) == d1 - t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) + t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) @test cond(t3) ≈ one(real(T)) @test rank(t3) == dim(V1 ⊗ V2) - t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) - t4 = (t4 + t4') / 2 + t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) + t4 = (t4 + t4') / 2 vals = LinearAlgebra.eigvals(t4) λmax = maximum(s -> maximum(abs, s), values(vals)) λmin = minimum(s -> minimum(abs, s), values(vals)) From 2e91d9d6f647edc252c528ee6096db4a08c65cce Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Sep 2025 06:46:20 -0400 Subject: [PATCH 080/126] Fix diagonal factorizations --- src/tensors/factorizations/diagonal.jl | 61 ++++++++++++++++++- src/tensors/factorizations/implementations.jl | 5 +- .../factorizations/matrixalgebrakit.jl | 20 +++++- test/diagonal.jl | 10 +-- 4 files changed, 88 insertions(+), 8 deletions(-) diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index 67084d589..86f40e1d5 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -8,13 +8,24 @@ for f in (:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, @eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) end -for f! in (:eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!) +for f! in (:eig_full!, :eig_trunc!) @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) return d, similar(d) end end +for f! in (:eigh_full!, :eigh_trunc!) + @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) + if scalartype(d) <: Real + return d, similar(d) + else + return similar(d, real(scalartype(d))), similar(d) + end + end +end + for f! in (:qr_full!, :qr_compact!) @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm) @@ -40,7 +51,7 @@ end for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, - :eigh_trunc!) + :eigh_trunc!, :right_orth!, :left_orth!) @eval function $f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm) check_input($f!, d, F, alg) $f!(_repack_diagonal(d), _repack_diagonal.(F), alg) @@ -92,9 +103,55 @@ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) end end +function check_input(::typeof(eig_full!), t::DiagonalTensorMap, (D, V)::_T_DV, + ::DiagonalAlgorithm) + domain(t) == codomain(t) || + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + + # scalartype checks + @check_scalar D t + @check_scalar V t + + # space checks + @check_space D space(t) + @check_space V space(t) + + return nothing +end + +function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, (D, V)::_T_DV, + ::DiagonalAlgorithm) + domain(t) == codomain(t) || + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + + # scalartype checks + @check_scalar D t real + @check_scalar V t + + # space checks + @check_space D space(t) + @check_space V space(t) + + return nothing +end + function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, ::DiagonalAlgorithm) @check_scalar D t @check_space D space(t) return nothing end + +function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, + ::DiagonalAlgorithm) + @check_scalar D t real + @check_space D space(t) + return nothing +end + +function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, + ::DiagonalAlgorithm) + @check_scalar D t real + @check_space D space(t) + return nothing +end diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index c898d4f33..f21c2ba93 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -7,6 +7,7 @@ _kindof(::LAPACK_HouseholderQR) = :qr _kindof(::LAPACK_HouseholderLQ) = :lq _kindof(::LAPACK_SVDAlgorithm) = :svd _kindof(::PolarViaSVD) = :polar +_kindof(::DiagonalAlgorithm) = :svd leftorth!(t; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) @@ -33,6 +34,7 @@ function _leftorth!(t, alg::Union{OFA,AbstractAlgorithm}; kwargs...) if kind == :svd alg_svd = alg === LAPACK_QRIteration() ? alg : alg === LAPACK_DivideAndConquer() ? alg : + alg === DiagonalAlgorithm() ? alg : alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) @@ -78,7 +80,7 @@ end function rightorth!(t::AbstractTensorMap; alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, - LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,RQ,RQpos,SVD, + LAPACK_DivideAndConquer,DiagonalAlgorithm,PolarViaSVD,LQ,LQpos,RQ,RQpos,SVD, SDD,Polar,Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:rightorth!) @@ -100,6 +102,7 @@ function rightorth!(t::AbstractTensorMap; if kind == :svd alg_svd = alg === LAPACK_QRIteration() ? alg : alg === LAPACK_DivideAndConquer() ? alg : + alg === DiagonalAlgorithm() ? alg : alg === SVD() ? LAPACK_QRIteration() : alg === SDD() ? LAPACK_DivideAndConquer() : throw(ArgumentError(lazy"Unknown algorithm $alg")) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index d6eb22a93..c47cfb548 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -197,7 +197,25 @@ function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV, : return nothing end -function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, ::AbstractAlgorithm) +function check_input(::typeof(eig_full!), t::DiagonalTensorMap, (D, V)::_T_DV, + ::AbstractAlgorithm) + domain(t) == codomain(t) || + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + + # scalartype checks + @check_scalar D t + @check_scalar V t + + # space checks + V_D = fuse(domain(t)) + @check_space(D, V_D ← V_D) + @check_space(V, codomain(t) ← V_D) + + return nothing +end + +function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, + ::AbstractAlgorithm) @check_scalar D t real V_D = fuse(domain(t)) @check_space(D, V_D ← V_D) diff --git a/test/diagonal.jl b/test/diagonal.jl index 0f67836dd..688e933f9 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -49,7 +49,9 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @test norm(zerovector!(t)) == 0 @test norm(one!(t)) ≈ sqrt(dim(V)) @test one!(t) == id(V) - @test norm(one!(t) - id(V)) == 0 + if T != BigFloat # seems broken for now + @test norm(one!(t) - id(V)) == 0 + end t1 = DiagonalTensorMap(rand(T, reduceddim(V)), V) t2 = DiagonalTensorMap(rand(T, reduceddim(V)), V) @@ -211,7 +213,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @test V2 == one(t) @test t2 * V2 ≈ V2 * D2 end - @testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QL()) + @testset "leftorth with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) Q, R = @constinferred leftorth(t; alg=alg) QdQ = Q' * Q @test QdQ ≈ one(QdQ) @@ -220,7 +222,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @test isposdef(R) end end - @testset "rightorth with $alg" for alg in (TensorKit.RQ(), TensorKit.LQ()) + @testset "rightorth with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) L, Q = @constinferred rightorth(t; alg=alg) QQd = Q * Q' @test QQd ≈ one(QQd) @@ -229,7 +231,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @test isposdef(L) end end - @testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD()) + @testset "tsvd with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) U, S, Vᴴ = @constinferred tsvd(t; alg=alg) UdU = U' * U @test UdU ≈ one(UdU) From 9297e4693107aa891161c647120e97e75d863f85 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Sep 2025 06:53:36 -0400 Subject: [PATCH 081/126] Format again --- src/tensors/factorizations/diagonal.jl | 4 ++-- src/tensors/factorizations/implementations.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index 86f40e1d5..eb34aab48 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -114,7 +114,7 @@ function check_input(::typeof(eig_full!), t::DiagonalTensorMap, (D, V)::_T_DV, # space checks @check_space D space(t) - @check_space V space(t) + @check_space V space(t) return nothing end @@ -130,7 +130,7 @@ function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, (D, V)::_T_DV, # space checks @check_space D space(t) - @check_space V space(t) + @check_space V space(t) return nothing end diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl index f21c2ba93..a25ed22d2 100644 --- a/src/tensors/factorizations/implementations.jl +++ b/src/tensors/factorizations/implementations.jl @@ -80,7 +80,8 @@ end function rightorth!(t::AbstractTensorMap; alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, - LAPACK_DivideAndConquer,DiagonalAlgorithm,PolarViaSVD,LQ,LQpos,RQ,RQpos,SVD, + LAPACK_DivideAndConquer,DiagonalAlgorithm,PolarViaSVD,LQ, + LQpos,RQ,RQpos,SVD, SDD,Polar,Nothing}=nothing, kwargs...) InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:rightorth!) From a2c58b676ee889df033b503c820e5b3eca7a7e61 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Sep 2025 16:32:00 -0400 Subject: [PATCH 082/126] Support QR and LQ --- Project.toml | 3 ++ src/TensorKit.jl | 2 ++ src/tensors/factorizations/factorizations.jl | 4 +++ test/factorizations.jl | 30 ++++++++++++++++++++ 4 files changed, 39 insertions(+) diff --git a/Project.toml b/Project.toml index 679b217b5..3d71eca49 100644 --- a/Project.toml +++ b/Project.toml @@ -63,3 +63,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"] + +[sources] +MatrixAlgebraKit = {url="https://github.com/QuantumKitHub/MatrixAlgebraKit.jl", rev="ksh/copyfix"} diff --git a/src/TensorKit.jl b/src/TensorKit.jl index bf359e8de..89fc6de02 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -73,6 +73,8 @@ export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby! export leftorth, rightorth, leftnull, rightnull, leftorth!, rightorth!, leftnull!, rightnull!, left_polar, left_polar!, right_polar, right_polar!, + qr_full, qr_compact, qr_null, lq_full, lq_compact, lq_null, + qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!, tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index e84cef284..2c78310b6 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -7,6 +7,10 @@ export eig, eig!, eigh, eigh! export tsvd, tsvd!, svdvals, svdvals! export leftorth, leftorth!, rightorth, rightorth! export leftnull, leftnull!, rightnull, rightnull! +export qr_full, qr_compact, qr_null +export qr_full!, qr_compact!, qr_null! +export lq_full, lq_compact, lq_null +export lq_full!, lq_compact!, lq_null! export copy_oftype, permutedcopy_oftype, one! export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD diff --git a/test/factorizations.jl b/test/factorizations.jl index b368b6f95..cf8521af6 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -29,6 +29,36 @@ for V in spacelist # Test both a normal tensor and an adjoint one. ts = (rand(T, W, W'), rand(T, W, W')', rand(T, V1, W'), rand(T, V1, W')') @testset for t in ts + @testset "qr_full" begin + Q, R = @constinferred qr_full(t) + @test isisometry(Q) + @test Q * R ≈ t + end + @testset "qr_compact" begin + Q, R = @constinferred qr_compact(t) + @test isisometry(Q) + @test Q * R ≈ t + end + @testset "qr_null" begin + N = @constinferred qr_null(t) + @test isisometry(N) + @test norm(N' * t) < 100 * eps(norm(t)) + end + @testset "lq_full" begin + L, Q = @constinferred lq_full(t) + @test isisometry(Q; side=:right) + @test L * Q ≈ t + end + @testset "lq_compact" begin + L, Q = @constinferred lq_compact(t) + @test isisometry(Q; side=:right) + @test L * Q ≈ t + end + @testset "lq_null" begin + Nᴴ = @constinferred lq_null(t) + @test isisometry(Nᴴ; side=:right) + @test norm(t * Nᴴ') < 100 * eps(norm(t)) + end @testset "leftorth with $alg" for alg in (TensorKit.LAPACK_HouseholderQR(), TensorKit.LAPACK_HouseholderQR(; From e909f38f35ea16c2da8c8574db3722c5d57e2fae Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Sep 2025 09:26:48 -0400 Subject: [PATCH 083/126] Bump MatrixAlgebraKit version --- Project.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 3d71eca49..9c0d55b7f 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.3.1" +MatrixAlgebraKit = "0.3.2" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" @@ -63,6 +63,3 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"] - -[sources] -MatrixAlgebraKit = {url="https://github.com/QuantumKitHub/MatrixAlgebraKit.jl", rev="ksh/copyfix"} From 049ec9a6553747169a96089c33e39020ef1edb49 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 18 Sep 2025 09:10:46 +0200 Subject: [PATCH 084/126] fix stackoverflow --- src/tensors/diagonal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 268f7f24f..7ab300ca0 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -99,7 +99,7 @@ function Base.copy!(t::AbstractTensorMap, d::DiagonalTensorMap) end return t end -TensorMap(d::DiagonalTensorMap) = copy!(similar(d), d) +TensorMap(d::DiagonalTensorMap) = copy!(similar(d, scalartype(d), space(d)), d) Base.convert(::Type{TensorMap}, d::DiagonalTensorMap) = TensorMap(d) function Base.convert(D::Type{<:DiagonalTensorMap}, d::DiagonalTensorMap) From 657f4fc57e89bedfd73c2978ecac72e7f39beaa2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 18 Sep 2025 10:07:19 +0200 Subject: [PATCH 085/126] improve BigFloat support --- src/tensors/vectorinterface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/vectorinterface.jl b/src/tensors/vectorinterface.jl index 92faa23be..f9eb0b6d1 100644 --- a/src/tensors/vectorinterface.jl +++ b/src/tensors/vectorinterface.jl @@ -68,7 +68,7 @@ function VectorInterface.add(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number, β::Number) space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))")) T = VectorInterface.promote_add(ty, tx, α, β) - return add!(scale!(similar(ty, T), ty, β), tx, α) + return add!(scale!(zerovector(ty, T), ty, β), tx, α) end function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number, β::Number) From a32f3ae9c310c528cf19af66c50feb9ce9ed2a3d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 20 Sep 2025 02:35:25 -0400 Subject: [PATCH 086/126] Last fixes --- Project.toml | 1 - src/tensors/factorizations/factorizations.jl | 56 -------------------- 2 files changed, 57 deletions(-) diff --git a/Project.toml b/Project.toml index 9c0d55b7f..66906c968 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,6 @@ OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" ScopedValues = "1.3.0" -SparseArrays = "1" Strided = "2" TensorKitSectors = "0.1.4, 0.2" TensorOperations = "5.1" diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 2c78310b6..460e91c63 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -62,66 +62,10 @@ end #------------------------------------------------------------------------------------------ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} -# DiagonalTensorMap -# ----------------- -function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...) - @assert alg isa Union{QR,QL} - return one(d), d # TODO: this is only correct for `alg = QR()` or `alg = QL()` -end -function rightorth!(d::DiagonalTensorMap; alg=LQ(), kwargs...) - @assert alg isa Union{LQ,RQ} - return d, one(d) # TODO: this is only correct for `alg = LQ()` or `alg = RQ()` -end -leftnull!(d::DiagonalTensorMap; kwargs...) = leftnull!(TensorMap(d); kwargs...) -rightnull!(d::DiagonalTensorMap; kwargs...) = rightnull!(TensorMap(d); kwargs...) - -function tsvd!(d::DiagonalTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD()) - return _tsvd!(d, alg, trunc, p) -end - -# helper function -function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD}) - InnerProductStyle(d) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) - I = sectortype(d) - dims = SectorDict{I,Int}() - generator = Base.Iterators.map(blocks(d)) do (c, b) - lb = length(b.diag) - U = zerovector!(similar(b.diag, lb, lb)) - V = zerovector!(similar(b.diag, lb, lb)) - p = sortperm(b.diag; by=abs, rev=true) - for (i, pi) in enumerate(p) - U[pi, i] = safesign(b.diag[pi]) - V[i, pi] = 1 - end - Σ = abs.(view(b.diag, p)) - dims[c] = lb - return c => (U, Σ, V) - end - SVDdata = SectorDict(generator) - return SVDdata, dims -end - -eig!(d::DiagonalTensorMap) = d, one(d) -eigh!(d::DiagonalTensorMap{<:Real}) = d, one(d) -eigh!(d::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(real(d.data), d.domain), one(d) - -function LinearAlgebra.svdvals(d::DiagonalTensorMap) - return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d)) -end -function LinearAlgebra.eigvals(d::DiagonalTensorMap) - return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d)) -end - -function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2) - return LinearAlgebra.cond(Diagonal(d.data), p) -end - #------------------------------# # LinearAlgebra overloads #------------------------------# -#LinearAlgebra.svdvals(t::AbstractTensorMap) = diagview(svd_vals(t)) LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t)) -#LinearAlgebra.eigvals(t::AbstractTensorMap) = diagview(eig_vals(t)) LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = diagview(eig_vals!(t)) #--------------------------------------------------# From 7176114b81ed84a9b4b4d8758699be5cd3cbd137 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 20 Sep 2025 10:31:01 -0400 Subject: [PATCH 087/126] Fix format --- .../factorizations/matrixalgebrakit.jl | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index c47cfb548..2f595909e 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -81,7 +81,8 @@ end const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} const _T_USVᴴ_diag = Tuple{<:AbstractTensorMap,<:DiagonalTensorMap,<:AbstractTensorMap} -function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ, ::AbstractAlgorithm) +function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ, + ::AbstractAlgorithm) # scalartype checks @check_scalar U t @check_scalar S t real @@ -97,7 +98,8 @@ function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T return nothing end -function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ_diag, ::AbstractAlgorithm) +function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ_diag, + ::AbstractAlgorithm) # scalartype checks @check_scalar U t @check_scalar S t real @@ -112,7 +114,8 @@ function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ): return nothing end -function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict, ::AbstractAlgorithm) +function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict, + ::AbstractAlgorithm) @check_scalar S t real V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) @check_space(S, V_cod ← V_dom) @@ -165,7 +168,8 @@ end # ------------------------ const _T_DV = Tuple{<:DiagonalTensorMap,<:AbstractTensorMap} -function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV, ::AbstractAlgorithm) +function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV, + ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -181,7 +185,8 @@ function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV, return nothing end -function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV, ::AbstractAlgorithm) +function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV, + ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -222,7 +227,8 @@ function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTens return nothing end -function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, ::AbstractAlgorithm) +function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, + ::AbstractAlgorithm) @check_scalar D t complex V_D = fuse(domain(t)) @check_space(D, V_D ← V_D) @@ -283,7 +289,8 @@ end # ---------------- const _T_QR = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR, ::AbstractAlgorithm) +function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR, + ::AbstractAlgorithm) # scalartype checks @check_scalar Q t @check_scalar R t @@ -296,7 +303,8 @@ function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR, :: return nothing end -function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR, ::AbstractAlgorithm) +function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR, + ::AbstractAlgorithm) # scalartype checks @check_scalar Q t @check_scalar R t @@ -309,7 +317,8 @@ function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR, return nothing end -function check_input(::typeof(qr_null!), t::AbstractTensorMap, N::AbstractTensorMap, ::AbstractAlgorithm) +function check_input(::typeof(qr_null!), t::AbstractTensorMap, N::AbstractTensorMap, + ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -346,7 +355,8 @@ end # ---------------- const _T_LQ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ, ::AbstractAlgorithm) +function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ, + ::AbstractAlgorithm) # scalartype checks @check_scalar L t @check_scalar Q t @@ -359,7 +369,8 @@ function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ, :: return nothing end -function check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)::_T_LQ, ::AbstractAlgorithm) +function check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)::_T_LQ, + ::AbstractAlgorithm) # scalartype checks @check_scalar L t @check_scalar Q t @@ -410,7 +421,8 @@ end const _T_WP = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} const _T_PWᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP, ::AbstractAlgorithm) +function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP, + ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) @@ -425,7 +437,8 @@ function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP, return nothing end -function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, (W, P)::_T_WP, ::AbstractAlgorithm) +function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, (W, P)::_T_WP, + ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) @@ -447,7 +460,8 @@ function initialize_output(::typeof(left_polar!), t::AbstractTensorMap, ::Abstra return W, P end -function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, ::AbstractAlgorithm) +function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, + ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) @@ -462,7 +476,8 @@ function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T return nothing end -function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, ::AbstractAlgorithm) +function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, + ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) @@ -500,7 +515,8 @@ end const _T_VC = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} const _T_CVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC, ::AbstractAlgorithm) +function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC, + ::AbstractAlgorithm) # scalartype checks @check_scalar V t isnothing(C) || @check_scalar C t @@ -513,7 +529,8 @@ function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC, return nothing end -function check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)::_T_CVᴴ, ::AbstractAlgorithm) +function check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)::_T_CVᴴ, + ::AbstractAlgorithm) # scalartype checks isnothing(C) || @check_scalar C t @check_scalar Vᴴ t From 38441ea45a0f6e0ddb46de03af3579ba3e2448cd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 21 Sep 2025 09:52:17 -0400 Subject: [PATCH 088/126] Add a few more Diagonal tests for coverage --- src/TensorKit.jl | 2 ++ test/diagonal.jl | 27 ++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 89fc6de02..b09d43512 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -76,6 +76,8 @@ export leftorth, rightorth, leftnull, rightnull, qr_full, qr_compact, qr_null, lq_full, lq_compact, lq_null, qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!, tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, + eigh_full!, eigh_full, eig_full!, eig_full, eigh_vals!, eigh_vals, + eig_vals!, eig_vals, isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! diff --git a/test/diagonal.jl b/test/diagonal.jl index 688e933f9..de0b19f44 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -188,7 +188,9 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @timedtestset "Factorization" begin for T in (Float32, ComplexF64) t = DiagonalTensorMap(rand(T, reduceddim(V)), V) - @testset "eig" begin + @testset "eig/eigh" begin + D, W = @constinferred eig_full(t) + @test t * W ≈ W * D D, W = @constinferred eig(t) @test t * W ≈ W * D t2 = t + t' @@ -196,6 +198,9 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), VdV2 = V2' * V2 @test VdV2 ≈ one(VdV2) @test t2 * V2 ≈ V2 * D2 + + D3 = @constinferred eigh_vals(t2) + @test D2 ≈ D3 @test rank(D) ≈ rank(t) @test cond(D) ≈ cond(t) @@ -231,6 +236,26 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @test isposdef(L) end end + @testset "qr_full with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) + Q, R = @constinferred qr_full(t; alg=alg) + @test isisometry(Q) + @test Q * R ≈ t + end + @testset "qr_compact with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) + Q, R = @constinferred qr_compact(t; alg=alg) + @test isisometry(Q) + @test Q * R ≈ t + end + @testset "lq_full with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) + L, Q = @constinferred lq_full(t; alg=alg) + @test isisometry(Q; side=:right) + @test L * Q ≈ t + end + @testset "lq_compact with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) + L, Q = @constinferred lq_compact(t; alg=alg) + @test isisometry(Q; side=:right) + @test L * Q ≈ t + end @testset "tsvd with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) U, S, Vᴴ = @constinferred tsvd(t; alg=alg) UdU = U' * U From e493d51620e7837c309534becccd4a14a2ec606d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 21 Sep 2025 10:29:13 -0400 Subject: [PATCH 089/126] Fix format again --- test/diagonal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/diagonal.jl b/test/diagonal.jl index de0b19f44..68f6cedc6 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -198,7 +198,7 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), VdV2 = V2' * V2 @test VdV2 ≈ one(VdV2) @test t2 * V2 ≈ V2 * D2 - + D3 = @constinferred eigh_vals(t2) @test D2 ≈ D3 From 8a97299a7a3185baade07d30e5d692b228f08c55 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 10:35:06 -0400 Subject: [PATCH 090/126] retain `AdjointTensorMap` in factorizations --- src/tensors/factorizations/adjoint.jl | 40 ++++++++++++++++++++------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 49f36777a..3ceb06359 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -1,21 +1,30 @@ # AdjointTensorMap # ---------------- +# map algorithms to their adjoint counterpart +# TODO: this probably belongs in MatrixAlgebraKit +_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.positive, alg.blocksize) +_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.positive, alg.blocksize) +_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.positive, alg.blocksize) +_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.positive, alg.blocksize) +_adjoint(alg::PolarViaSVD) = PolarViaSVD(_adjoint(alg.svdalg)) +_adjoint(alg::AbstractAlgorithm) = alg + # 1-arg functions function initialize_output(::typeof(left_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) - return adjoint(initialize_output(right_null!, adjoint(t), alg)) + return adjoint(initialize_output(right_null!, adjoint(t), _adjoint(alg))) end function initialize_output(::typeof(right_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) - return adjoint(initialize_output(left_null!, adjoint(t), alg)) + return adjoint(initialize_output(left_null!, adjoint(t), _adjoint(alg))) end function left_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) - right_null!(adjoint(t), adjoint(N), alg) + right_null!(adjoint(t), adjoint(N), _adjoint(alg)) return N end function right_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) - left_null!(adjoint(t), adjoint(N), alg) + left_null!(adjoint(t), adjoint(N), _adjoint(alg)) return N end @@ -29,40 +38,51 @@ end # 2-arg functions for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!), (:lq_full!, :lq_compact!, :right_polar!, :right_orth!)) + @eval function copy_input(::typeof($left_f!), t::AdjointTensorMap) + return adjoint(copy_input($right_f!, adjoint(t))) + end + @eval function copy_input(::typeof($right_f!), t::AdjointTensorMap) + return adjoint(copy_input($left_f!, adjoint(t))) + end + @eval function initialize_output(::typeof($left_f!), t::AdjointTensorMap, alg::AbstractAlgorithm) - return reverse(adjoint.(initialize_output($right_f!, adjoint(t), alg))) + return reverse(adjoint.(initialize_output($right_f!, adjoint(t), _adjoint(alg)))) end @eval function initialize_output(::typeof($right_f!), t::AdjointTensorMap, alg::AbstractAlgorithm) - return reverse(adjoint.(initialize_output($left_f!, adjoint(t), alg))) + return reverse(adjoint.(initialize_output($left_f!, adjoint(t), _adjoint(alg)))) end @eval function $left_f!(t::AdjointTensorMap, F::Tuple{AdjointTensorMap,AdjointTensorMap}, alg::AbstractAlgorithm) - $right_f!(adjoint(t), reverse(adjoint.(F)), alg) + $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end @eval function $right_f!(t::AdjointTensorMap, F::Tuple{AdjointTensorMap,AdjointTensorMap}, alg::AbstractAlgorithm) - $left_f!(adjoint(t), reverse(adjoint.(F)), alg) + $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end end # 3-arg functions for f! in (:svd_full!, :svd_compact!, :svd_trunc!) + @eval function copy_input(::typeof($f!), t::AdjointTensorMap) + return adjoint(copy_input($f!, adjoint(t))) + end + @eval function initialize_output(::typeof($f!), t::AdjointTensorMap, alg::AbstractAlgorithm) - return reverse(adjoint.(initialize_output($f!, adjoint(t), alg))) + return reverse(adjoint.(initialize_output($f!, adjoint(t), _adjoint(alg)))) end _TS = f! === :svd_full! ? :AdjointTensorMap : DiagonalTensorMap @eval function $f!(t::AdjointTensorMap, F::Tuple{AdjointTensorMap,$_TS,AdjointTensorMap}, alg::AbstractAlgorithm) - $f!(adjoint(t), reverse(adjoint.(F)), alg) + $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end end From 6ff08350b47a5fa388071a8689f7934b5254f54d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 13:32:00 -0400 Subject: [PATCH 091/126] improve deprecation --- Project.toml | 2 +- src/TensorKit.jl | 1 - src/auxiliary/linalg.jl | 43 ---- src/auxiliary/random.jl | 2 +- src/tensors/factorizations/deprecations.jl | 127 +++++++++++ src/tensors/factorizations/factorizations.jl | 7 +- src/tensors/factorizations/interface.jl | 220 ------------------- src/tensors/factorizations/truncation.jl | 2 - 8 files changed, 133 insertions(+), 271 deletions(-) diff --git a/Project.toml b/Project.toml index 66906c968..1986e2963 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorKit" uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" authors = ["Jutho Haegeman"] -version = "0.14.11" +version = "0.15.0-DEV" [deps] LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index b09d43512..6c4f86b04 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -148,7 +148,6 @@ include("auxiliary/auxiliary.jl") include("auxiliary/caches.jl") include("auxiliary/dicts.jl") include("auxiliary/iterators.jl") -include("auxiliary/linalg.jl") include("auxiliary/random.jl") #-------------------------------------------------------------------- diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl index 4a3bf9b15..c39a6b435 100644 --- a/src/auxiliary/linalg.jl +++ b/src/auxiliary/linalg.jl @@ -2,46 +2,3 @@ #------------------------------------------------------ set_num_blas_threads(n::Integer) = LinearAlgebra.BLAS.set_num_threads(n) get_num_blas_threads() = LinearAlgebra.BLAS.get_num_threads() - -# Factorization algorithms -#-------------------------- -abstract type FactorizationAlgorithm end -abstract type OrthogonalFactorizationAlgorithm <: FactorizationAlgorithm end - -struct QRpos <: OrthogonalFactorizationAlgorithm -end -struct QR <: OrthogonalFactorizationAlgorithm -end -struct QL <: OrthogonalFactorizationAlgorithm -end -struct QLpos <: OrthogonalFactorizationAlgorithm -end -struct LQ <: OrthogonalFactorizationAlgorithm -end -struct LQpos <: OrthogonalFactorizationAlgorithm -end -struct RQ <: OrthogonalFactorizationAlgorithm -end -struct RQpos <: OrthogonalFactorizationAlgorithm -end -struct SDD <: OrthogonalFactorizationAlgorithm # lapack's default divide and conquer algorithm -end -struct SVD <: OrthogonalFactorizationAlgorithm -end -struct Polar <: OrthogonalFactorizationAlgorithm -end - -Base.adjoint(::QRpos) = LQpos() -Base.adjoint(::QR) = LQ() -Base.adjoint(::LQpos) = QRpos() -Base.adjoint(::LQ) = QR() - -Base.adjoint(::QLpos) = RQpos() -Base.adjoint(::QL) = RQ() -Base.adjoint(::RQpos) = QLpos() -Base.adjoint(::RQ) = QL() - -Base.adjoint(alg::Union{SVD,SDD,Polar}) = alg - -const OFA = OrthogonalFactorizationAlgorithm -const SVDAlg = Union{SVD,SDD} diff --git a/src/auxiliary/random.jl b/src/auxiliary/random.jl index 3289cdc3c..fe57f585e 100644 --- a/src/auxiliary/random.jl +++ b/src/auxiliary/random.jl @@ -20,6 +20,6 @@ function randisometry!(rng::Random.AbstractRNG, A::AbstractMatrix) dims = size(A) dims[1] >= dims[2] || throw(DimensionMismatch("cannot create isometric matrix with dimensions $dims; isometry needs to be tall or square")) - Q, = leftorth!(Random.randn!(rng, A); alg=QRpos()) + Q, = qr_compact!(Random.randn!(rng, A); positive=true) return copy!(A, Q) end diff --git a/src/tensors/factorizations/deprecations.jl b/src/tensors/factorizations/deprecations.jl index 8b1378917..c610602e9 100644 --- a/src/tensors/factorizations/deprecations.jl +++ b/src/tensors/factorizations/deprecations.jl @@ -1 +1,128 @@ +# Factorization structs +@deprecate QR() LAPACK_HouseholderQR() +@deprecate QRpos() LAPACK_HouseholderQR(; positive=true) +@deprecate QL() LAPACK_HouseholderQL() +@deprecate QLpos() LAPACK_HouseholderQL(; positive=true) + +@deprecate LQ() LAPACK_HouseholderLQ() +@deprecate LQpos() LAPACK_HouseholderLQ(; positive=true) + +@deprecate RQ() LAPACK_HouseholderRQ() +@deprecate RQpos() LAPACK_HouseholderRQ(; positive=true) + +@deprecate SDD() LAPACK_DivideAndConquer() +@deprecate SVD() LAPACK_QRIteration() + +@deprecate Polar() PolarViaSVD(LAPACK_DivideAndConquer()) + +# truncations +const TruncationScheme = TruncationStrategy +@deprecate truncdim(d::Int) truncrank(d) +@deprecate truncbelow(ϵ::Real) trunctol(ϵ) + +# factorizations +# -------------- +# orthogonalization +@deprecate(leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + leftorth!(permutedcopy_oftype(t, factorization_scalartype(leftorth, t), p); + kwargs...)) +@deprecate(rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + rightorth!(permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p); + kwargs...)) +function leftorth(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) + return left_orth(t; kwargs...) +end +function leftorth!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftorth!` is no longer supported, use `left_orth!` instead", :leftorth!) + return left_orth!(t; kwargs...) +end +function rightorth(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) + return right_orth(t; kwargs...) +end +function rightorth!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightorth!` is no longer supported, use `right_orth!` instead", + :rightorth!) + return right_orth!(t; kwargs...) +end + +# nullspaces +@deprecate(leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + leftnull!(permutedcopy_oftype(t, factorization_scalartype(leftnull, t), p); + kwargs...)) +@deprecate(rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + rightnull!(permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p); + kwargs...)) +function leftnull(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) + return left_null(t; kwargs...) +end +function leftnull!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`left_null!` is no longer supported, use `left_null!` instead", + :leftnull!) + return left_null!(t; kwargs...) +end +function rightnull(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) + return right_null(t; kwargs...) +end +function rightnull!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightnull!` is no longer supported, use `right_null!` instead", + :rightnull!) + return right_null!(t; kwargs...) +end + +# eigen values +@deprecate(eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + eig!(permutedcopy_oftype(t, factorisation_scalartype(eig, t), p); kwargs...)) +@deprecate(eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + eigh!(permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p); kwargs...)) +@deprecate(eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + eigen!(permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p); kwargs...)) +function eig(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eig` is no longer supported, use `eig_full` or `eig_trunc` instead", + :eig) + return haskey(kwargs, :trunc) ? eig_trunc(t; kwargs...) : eig_full(t; kwargs...) +end +function eig!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eig!` is no longer supported, use `eig_full!` or `eig_trunc!` instead", + :eig!) + return haskey(kwargs, :trunc) ? eig_trunc!(t; kwargs...) : eig_full!(t; kwargs...) +end +function eigh(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eigh` is no longer supported, use `eigh_full` or `eigh_trunc` instead", + :eigh) + return haskey(kwargs, :trunc) ? eigh_trunc(t; kwargs...) : eigh_full(t; kwargs...) +end +function eigh!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eigh!` is no longer supported, use `eigh_full!` or `eigh_trunc!` instead", + :eigh!) + return haskey(kwargs, :trunc) ? eigh_trunc!(t; kwargs...) : eigh_full!(t; kwargs...) +end + +# singular values +_drop_p(; p=nothing, kwargs...) = kwargs +@deprecate(tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + tsvd!(permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p); kwargs...)) +function tsvd(t::AbstractTensorMap; kwargs...) + Base.depwarn("`tsvd` is no longer supported, use `svd_compact`, `svd_full` or `svd_trunc` instead", + :tsvd) + if haskey(kwargs, :p) + Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", + :tsvd) + kwargs = _drop_p(; kwargs...) + end + return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : svd_compact(t; kwargs...) +end +function tsvd!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`tsvd!` is no longer supported, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", + :tsvd!) + if haskey(kwargs, :p) + Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", + :tsvd!) + kwargs = _drop_p(; kwargs...) + end + return haskey(kwargs, :trunc) ? svd_trunc!(t; kwargs...) : svd_compact!(t; kwargs...) +end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 460e91c63..6ddcf79a0 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -15,7 +15,7 @@ export copy_oftype, permutedcopy_oftype, one! export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD using ..TensorKit -using ..TensorKit: AdjointTensorMap, SectorDict, OFA, blocktype, foreachblock, one! +using ..TensorKit: AdjointTensorMap, SectorDict, blocktype, foreachblock, one! using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals! import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian @@ -27,7 +27,8 @@ using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrateg NoTruncation, TruncationKeepAbove, TruncationKeepBelow, TruncationIntersection, TruncationKeepFiltered, PolarViaSVD, LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR, - LAPACK_HouseholderLQ, DiagonalAlgorithm + LAPACK_HouseholderLQ, LAPACK_HouseholderQL, LAPACK_HouseholderRQ, + DiagonalAlgorithm import MatrixAlgebraKit: default_algorithm, copy_input, check_input, initialize_output, qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, @@ -42,7 +43,7 @@ import MatrixAlgebraKit: default_algorithm, include("utility.jl") include("interface.jl") -include("implementations.jl") +# include("implementations.jl") include("matrixalgebrakit.jl") include("truncation.jl") include("deprecations.jl") diff --git a/src/tensors/factorizations/interface.jl b/src/tensors/factorizations/interface.jl index 22c3c75d5..fba7195a4 100644 --- a/src/tensors/factorizations/interface.jl +++ b/src/tensors/factorizations/interface.jl @@ -1,208 +1,3 @@ -@doc """ - tsvd(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; - trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD()) - -> U, S, V, ϵ - tsvd!(t::AbstractTensorMap, trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD()) - -> U, S, V, ϵ - -Compute the (possibly truncated) singular value decomposition such that -`norm(permute(t, (leftind, rightind)) - U * S * V) ≈ ϵ`, where `ϵ` thus represents the truncation error. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `tsvd!(t, trunc = notrunc(), p = 2)`. - -A truncation parameter `trunc` can be specified for the new internal dimension, in which -case a truncated singular value decomposition will be computed. Choices are: -* `notrunc()`: no truncation (default); -* `truncerr(η::Real)`: truncates such that the p-norm of the truncated singular values is - smaller than `η`; -* `truncdim(χ::Int)`: truncates such that the equivalent total dimension of the internal - vector space is no larger than `χ`; -* `truncspace(V)`: truncates such that the dimension of the internal vector space is - smaller than that of `V` in any sector. -* `truncbelow(η::Real)`: truncates such that every singular value is larger then `η` ; - -Truncation options can also be combined using `&`, i.e. `truncbelow(η) & truncdim(χ)` will -choose the truncation space such that every singular value is larger than `η`, and the -equivalent total dimension of the internal vector space is no larger than `χ`. - -The method `tsvd` also returns the truncation error `ϵ`, computed as the `p` norm of the -singular values that were truncated. - -The keyword `alg` can be equal to `SVD()` or `SDD()`, corresponding to the underlying LAPACK -algorithm that computes the decomposition (`_gesvd` or `_gesdd`). - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)` -is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`. -""" tsvd, tsvd! - -@doc """ - eig(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> D, V - eig!(t::AbstractTensorMap; kwargs...) -> D, V - -Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. -The function `eig` assumes that the linear map is not hermitian and returns type stable -complex valued `D` and `V` tensors for both real and complex valued `t`. See `eigh` for -hermitian linear maps - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `eig!(t)`. Note that the permuted tensor on -which `eig!` is called should have equal domain and codomain, as otherwise the eigenvalue -decomposition is meaningless and cannot satisfy -``` -permute(t, (leftind, rightind)) * V = V * D -``` - -Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense -matrices. See the corresponding documentation for more information. - -See also `eigen` and `eigh`. -""" eig - -@doc """ - eigh(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> D, V - eigh!(t::AbstractTensorMap; kwargs...) -> D, V - -Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. -The function `eigh` assumes that the linear map is hermitian and `D` and `V` tensors with -the same `scalartype` as `t`. See `eig` and `eigen` for non-hermitian tensors. Hermiticity -requires that the tensor acts on inner product spaces, and the current implementation -requires `InnerProductStyle(t) === EuclideanInnerProduct()`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `eigh!(t)`. Note that the permuted tensor on -which `eigh!` is called should have equal domain and codomain, as otherwise the eigenvalue -decomposition is meaningless and cannot satisfy -``` -permute(t, (leftind, rightind)) * V = V * D -``` - -See also `eigen` and `eig`. -""" eigh, eigh! - -@doc """ - leftorth(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple; - alg::OrthogonalFactorizationAlgorithm = QRpos()) -> Q, R - -Create orthonormal basis `Q` for indices in `leftind`, and remainder `R` such that -`permute(t, (leftind, rightind)) = Q*R`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `leftorth!(t, alg = QRpos())`. - -Different algorithms are available, namely `QR()`, `QRpos()`, `SVD()` and `Polar()`. `QR()` -and `QRpos()` use a standard QR decomposition, producing an upper triangular matrix `R`. -`Polar()` produces a Hermitian and positive semidefinite `R`. `QRpos()` corrects the -standard QR decomposition such that the diagonal elements of `R` are positive. Only -`QRpos()` and `Polar()` are unique (no residual freedom) so that they always return the same -result for the same input tensor `t`. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`leftorth(!)` is currently only implemented for - `InnerProductStyle(t) === EuclideanInnerProduct()`. -""" leftorth, leftorth! - -@doc """ - rightorth(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; - alg::OrthogonalFactorizationAlgorithm = LQpos()) -> L, Q - rightorth!(t::AbstractTensorMap; alg) -> L, Q - -Create orthonormal basis `Q` for indices in `rightind`, and remainder `L` such that -`permute(t, (leftind, rightind)) = L*Q`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `rightorth!(t, alg = LQpos())`. - -Different algorithms are available, namely `LQ()`, `LQpos()`, `RQ()`, `RQpos()`, `SVD()` and -`Polar()`. `LQ()` and `LQpos()` produce a lower triangular matrix `L` and are computed using -a QR decomposition of the transpose. `RQ()` and `RQpos()` produce an upper triangular -remainder `L` and only works if the total left dimension is smaller than or equal to the -total right dimension. `LQpos()` and `RQpos()` add an additional correction such that the -diagonal elements of `L` are positive. `Polar()` produces a Hermitian and positive -semidefinite `L`. Only `LQpos()`, `RQpos()` and `Polar()` are unique (no residual freedom) -so that they always return the same result for the same input tensor `t`. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`rightorth(!)` is currently only implemented for -`InnerProductStyle(t) === EuclideanInnerProduct()`. -""" rightorth, rightorth! - -@doc """ - leftnull(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; - alg::OrthogonalFactorizationAlgorithm = QRpos()) -> N - leftnull!(t::AbstractTensorMap; alg) -> N - -Create orthonormal basis for the orthogonal complement of the support of the indices in -`leftind`, such that `N' * permute(t, (leftind, rightind)) = 0`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `leftnull!(t, alg = QRpos())`. - -Different algorithms are available, namely `QR()` (or equivalently, `QRpos()`), `SVD()` and -`SDD()`. The first assumes that the matrix is full rank and requires `iszero(atol)` and -`iszero(rtol)`. With `SVD()` and `SDD()`, `rightnull` will use the corresponding singular -value decomposition, and one can specify an absolute or relative tolerance for which -singular values are to be considered zero, where `max(atol, norm(t)*rtol)` is used as upper -bound. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`leftnull(!)` is currently only implemented for -`InnerProductStyle(t) === EuclideanInnerProduct()`. -""" leftnull, leftnull! - -@doc """ - rightnull(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; - alg::OrthogonalFactorizationAlgorithm = LQ(), - atol::Real = 0.0, - rtol::Real = eps(real(float(one(scalartype(t)))))*iszero(atol)) -> N - rightnull!(t::AbstractTensorMap; alg, atol, rtol) - -Create orthonormal basis for the orthogonal complement of the support of the indices in -`rightind`, such that `permute(t, (leftind, rightind))*N' = 0`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `rightnull!(t, alg = LQpos())`. - -Different algorithms are available, namely `LQ()` (or equivalently, `LQpos`), `SVD()` and -`SDD()`. The first assumes that the matrix is full rank and requires `iszero(atol)` and -`iszero(rtol)`. With `SVD()` and `SDD()`, `rightnull` will use the corresponding singular -value decomposition, and one can specify an absolute or relative tolerance for which -singular values are to be considered zero, where `max(atol, norm(t)*rtol)` is used as upper -bound. - -Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and -`rightnull(!)` is currently only implemented for -`InnerProductStyle(t) === EuclideanInnerProduct()`. -""" rightnull, rightnull! - -@doc """ - eigen(t::AbstractTensorMap, [(leftind, rightind)::Index2Tuple]; kwargs...) -> D, V - eigen!(t::AbstractTensorMap; kwargs...) -> D, V - -Compute eigenvalue factorization of tensor `t` as linear map from `rightind` to `leftind`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in `t` -to be destroyed/overwritten, by using `eigen!(t)`. Note that the permuted tensor on which -`eigen!` is called should have equal domain and codomain, as otherwise the eigenvalue -decomposition is meaningless and cannot satisfy -``` -permute(t, (leftind, rightind)) * V = V * D -``` - -Accepts the same keyword arguments `scale` and `permute` as `eigen` of dense -matrices. See the corresponding documentation for more information. - -See also [`eig(!)`](@ref eig) and [`eigh(!)`](@ref) -""" eigen(::AbstractTensorMap), eigen!(::AbstractTensorMap) - @doc """ isposdef(t::AbstractTensor, [(leftind, rightind)::Index2Tuple]) -> ::Bool @@ -215,21 +10,6 @@ which `isposdef!` is called should have equal domain and codomain, as otherwise meaningless. """ isposdef(::AbstractTensorMap), isposdef!(::AbstractTensorMap) -for f in - (:tsvd, :eig, :eigh, :eigen, :leftorth, :rightorth, :left_polar, :right_polar, - :leftnull, - :rightnull, :isposdef) - f! = Symbol(f, :!) - @eval function $f(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - tcopy = permutedcopy_oftype(t, factorisation_scalartype($f, t), p) - return $f!(tcopy; kwargs...) - end - @eval function $f(t::AbstractTensorMap; kwargs...) - tcopy = copy_oftype(t, factorisation_scalartype($f, t)) - return $f!(tcopy; kwargs...) - end -end - function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...) tcopy = copy_oftype(t, factorisation_scalartype(eigen, t)) return LinearAlgebra.eigvals!(tcopy; kwargs...) diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index d553c5c93..91032d56e 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -4,8 +4,6 @@ notrunc() = NoTruncation() # deprecate const TruncationScheme = TruncationStrategy -@deprecate truncdim(d::Int) truncrank(d) -@deprecate truncbelow(ϵ::Real, add_back::Int=0) trunctol(ϵ) # TODO: add this to MatrixAlgebraKit struct TruncationError{T<:Real} <: TruncationStrategy From 529410e4d7fce08c349188a0202afc10a1f9e2c7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 13:57:42 -0400 Subject: [PATCH 092/126] cleanup --- src/tensors/factorizations/adjoint.jl | 21 +- src/tensors/factorizations/diagonal.jl | 34 ++-- src/tensors/factorizations/factorizations.jl | 14 +- src/tensors/factorizations/implementations.jl | 192 ------------------ src/tensors/factorizations/interface.jl | 21 -- .../factorizations/matrixalgebrakit.jl | 142 ++++++++----- src/tensors/factorizations/truncation.jl | 14 +- 7 files changed, 143 insertions(+), 295 deletions(-) delete mode 100644 src/tensors/factorizations/implementations.jl delete mode 100644 src/tensors/factorizations/interface.jl diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 3ceb06359..df7189c54 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -19,11 +19,11 @@ function initialize_output(::typeof(right_null!), t::AdjointTensorMap, return adjoint(initialize_output(left_null!, adjoint(t), _adjoint(alg))) end -function left_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) +function left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) right_null!(adjoint(t), adjoint(N), _adjoint(alg)) return N end -function right_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) +function right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) left_null!(adjoint(t), adjoint(N), _adjoint(alg)) return N end @@ -54,15 +54,11 @@ for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_ort return reverse(adjoint.(initialize_output($left_f!, adjoint(t), _adjoint(alg)))) end - @eval function $left_f!(t::AdjointTensorMap, - F::Tuple{AdjointTensorMap,AdjointTensorMap}, - alg::AbstractAlgorithm) + @eval function $left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end - @eval function $right_f!(t::AdjointTensorMap, - F::Tuple{AdjointTensorMap,AdjointTensorMap}, - alg::AbstractAlgorithm) + @eval function $right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end @@ -78,10 +74,7 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) alg::AbstractAlgorithm) return reverse(adjoint.(initialize_output($f!, adjoint(t), _adjoint(alg)))) end - _TS = f! === :svd_full! ? :AdjointTensorMap : DiagonalTensorMap - @eval function $f!(t::AdjointTensorMap, - F::Tuple{AdjointTensorMap,$_TS,AdjointTensorMap}, - alg::AbstractAlgorithm) + @eval function $f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end @@ -92,9 +85,7 @@ function initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap, return initialize_output(svd_compact!, t, alg.alg) end # to fix ambiguity -function svd_trunc!(t::AdjointTensorMap, - USVᴴ::Tuple{AdjointTensorMap,DiagonalTensorMap,AdjointTensorMap}, - alg::TruncatedAlgorithm) +function svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) return truncate!(svd_trunc!, USVᴴ′, alg.trunc) end diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index eb34aab48..1f4445641 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -60,8 +60,9 @@ for f! in end for f! in (:qr_full!, :qr_compact!) - @eval function check_input(::typeof($f!), d::AbstractTensorMap, (Q, R)::_T_QR, + @eval function check_input(::typeof($f!), d::AbstractTensorMap, QR, ::DiagonalAlgorithm) + Q, R = QR @assert d isa DiagonalTensorMap @assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap @check_scalar Q d @@ -74,8 +75,9 @@ for f! in (:qr_full!, :qr_compact!) end for f! in (:lq_full!, :lq_compact!) - @eval function check_input(::typeof($f!), d::AbstractTensorMap, (L, Q)::_T_LQ, + @eval function check_input(::typeof($f!), d::AbstractTensorMap, LQ, ::DiagonalAlgorithm) + L, Q = LQ @assert d isa DiagonalTensorMap @assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap @check_scalar Q d @@ -103,11 +105,15 @@ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) end end -function check_input(::typeof(eig_full!), t::DiagonalTensorMap, (D, V)::_T_DV, - ::DiagonalAlgorithm) +function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + D, V = DV + + @assert D isa DiagonalTensorMap + @assert V isa AbstractTensorMap + # scalartype checks @check_scalar D t @check_scalar V t @@ -119,11 +125,15 @@ function check_input(::typeof(eig_full!), t::DiagonalTensorMap, (D, V)::_T_DV, return nothing end -function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, (D, V)::_T_DV, - ::DiagonalAlgorithm) +function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + D, V = DV + + @assert D isa DiagonalTensorMap + @assert V isa AbstractTensorMap + # scalartype checks @check_scalar D t real @check_scalar V t @@ -135,22 +145,22 @@ function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, (D, V)::_T_DV, return nothing end -function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, - ::DiagonalAlgorithm) +function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) + @assert D isa DiagonalTensorMap @check_scalar D t @check_space D space(t) return nothing end -function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, - ::DiagonalAlgorithm) +function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) + @assert D isa DiagonalTensorMap @check_scalar D t real @check_space D space(t) return nothing end -function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, - ::DiagonalAlgorithm) +function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) + @assert D isa DiagonalTensorMap @check_scalar D t real @check_space D space(t) return nothing diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 6ddcf79a0..a0efe8066 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -42,8 +42,6 @@ import MatrixAlgebraKit: default_algorithm, diagview, isisometry include("utility.jl") -include("interface.jl") -# include("implementations.jl") include("matrixalgebrakit.jl") include("truncation.jl") include("deprecations.jl") @@ -66,9 +64,19 @@ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} #------------------------------# # LinearAlgebra overloads #------------------------------# -LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t)) + +function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...) + tcopy = copy_oftype(t, factorisation_scalartype(eigen, t)) + return LinearAlgebra.eigvals!(tcopy; kwargs...) +end LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = diagview(eig_vals!(t)) +function LinearAlgebra.svdvals(t::AbstractTensorMap) + tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t)) + return LinearAlgebra.svdvals!(tcopy) +end +LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t)) + #--------------------------------------------------# # Checks for hermiticity and positive definiteness # #--------------------------------------------------# diff --git a/src/tensors/factorizations/implementations.jl b/src/tensors/factorizations/implementations.jl deleted file mode 100644 index a25ed22d2..000000000 --- a/src/tensors/factorizations/implementations.jl +++ /dev/null @@ -1,192 +0,0 @@ -_kindof(::Union{SVD,SDD}) = :svd -_kindof(::Union{QR,QRpos}) = :qr -_kindof(::Union{LQ,LQpos}) = :lq -_kindof(::Polar) = :polar - -_kindof(::LAPACK_HouseholderQR) = :qr -_kindof(::LAPACK_HouseholderLQ) = :lq -_kindof(::LAPACK_SVDAlgorithm) = :svd -_kindof(::PolarViaSVD) = :polar -_kindof(::DiagonalAlgorithm) = :svd - -leftorth!(t; alg=nothing, kwargs...) = _leftorth!(t, alg; kwargs...) - -function _leftorth!(t::AbstractTensorMap, alg::Nothing, ; kwargs...) - return isempty(kwargs) ? left_orth!(t) : left_orth!(t; trunc=(; kwargs...)) -end -function _leftorth!(t::AbstractTensorMap, alg::Union{QL,QLpos}; kwargs...) - trunc = isempty(kwargs) ? nothing : (; kwargs...) - - if alg == QL() || alg == QLpos() - _reverse!(t; dims=2) - Q, R = left_orth!(t; kind=:qr, alg_qr=(; positive=alg == QLpos()), trunc) - _reverse!(Q; dims=2) - _reverse!(R) - return Q, R - end -end -function _leftorth!(t, alg::Union{OFA,AbstractAlgorithm}; kwargs...) - trunc = isempty(kwargs) ? nothing : (; kwargs...) - - alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :leftorth!) - - kind = _kindof(alg) - if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : - alg === DiagonalAlgorithm() ? alg : - alg === SVD() ? LAPACK_QRIteration() : - alg === SDD() ? LAPACK_DivideAndConquer() : - throw(ArgumentError(lazy"Unknown algorithm $alg")) - return left_orth!(t; kind, alg_svd, trunc) - elseif kind == :qr - alg_qr = (; positive=(alg == QRpos())) - return left_orth!(t; kind, alg_qr, trunc) - elseif kind == :polar - return left_orth!(t; kind, trunc) - else - throw(ArgumentError(lazy"Invalid algorithm: $alg")) - end -end -# fallback to MatrixAlgebraKit version -_leftorth!(t, alg; kwargs...) = left_orth!(t, alg; kwargs...) - -function leftnull!(t::AbstractTensorMap; - alg::Union{LAPACK_HouseholderQR,LAPACK_QRIteration, - LAPACK_DivideAndConquer,PolarViaSVD,QR,QRpos,SVD,SDD,Nothing}=nothing, - kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:leftnull!) - trunc = isempty(kwargs) ? nothing : (; kwargs...) - alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :leftnull!) - - isnothing(alg) && return left_null!(t; trunc) - - kind = _kindof(alg) - if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : - alg === SVD() ? LAPACK_QRIteration() : - alg === SDD() ? LAPACK_DivideAndConquer() : - throw(ArgumentError(lazy"Unknown algorithm $alg")) - return left_null!(t; kind, alg_svd, trunc) - elseif kind == :qr - alg_qr = (; positive=(alg == QRpos())) - return left_null!(t; kind, alg_qr, trunc) - else - throw(ArgumentError(lazy"Invalid `leftnull!` algorithm: $alg")) - end -end - -function rightorth!(t::AbstractTensorMap; - alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, - LAPACK_DivideAndConquer,DiagonalAlgorithm,PolarViaSVD,LQ, - LQpos,RQ,RQpos,SVD, - SDD,Polar,Nothing}=nothing, kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightorth!) - trunc = isempty(kwargs) ? nothing : (; kwargs...) - - alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :rightorth!) - - isnothing(alg) && return right_orth!(t; trunc) - - if alg == RQ() || alg == RQpos() - _reverse!(t; dims=1) - L, Q = right_orth!(t; kind=:lq, alg_lq=(; positive=alg == RQpos()), trunc) - _reverse!(Q; dims=1) - _reverse!(L) - return L, Q - end - - kind = _kindof(alg) - if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : - alg === DiagonalAlgorithm() ? alg : - alg === SVD() ? LAPACK_QRIteration() : - alg === SDD() ? LAPACK_DivideAndConquer() : - throw(ArgumentError(lazy"Unknown algorithm $alg")) - return right_orth!(t; kind, alg_svd, trunc) - elseif kind == :lq - alg_lq = (; positive=(alg == LQpos())) - return right_orth!(t; kind, alg_lq, trunc) - elseif kind == :polar - return right_orth!(t; kind, trunc) - else - throw(ArgumentError(lazy"Invalid `rightorth!` algorithm: $alg")) - end -end - -function rightnull!(t::AbstractTensorMap; - alg::Union{LAPACK_HouseholderLQ,LAPACK_QRIteration, - LAPACK_DivideAndConquer,PolarViaSVD,LQ,LQpos,SVD,SDD, - Nothing}=nothing, kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || - throw_invalid_innerproduct(:rightnull!) - trunc = isempty(kwargs) ? nothing : (; kwargs...) - - alg isa OFA && Base.depwarn(lazy"$alg is deprecated", :rightnull!) - - isnothing(alg) && return right_null!(t; trunc) - - kind = _kindof(alg) - if kind == :svd - alg_svd = alg === LAPACK_QRIteration() ? alg : - alg === LAPACK_DivideAndConquer() ? alg : - alg === SVD() ? LAPACK_QRIteration() : - alg === SDD() ? LAPACK_DivideAndConquer() : - throw(ArgumentError(lazy"Unknown algorithm $alg")) - return right_null!(t; kind, alg_svd, trunc) - elseif kind == :lq - alg_lq = (; positive=(alg == LQpos())) - return right_null!(t; kind, alg_lq, trunc) - else - throw(ArgumentError(lazy"Invalid `rightnull!` algorithm: $alg")) - end -end - -# Eigenvalue decomposition -# ------------------------ -function eigh!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!) - if trunc == notrunc() - return eigh_full!(t; kwargs...) - else - return eigh_trunc!(t; trunc, kwargs...) - end -end - -function eig!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eig!) - - if trunc == notrunc() - return eig_full!(t; kwargs...) - else - return eig_trunc!(t; trunc, kwargs...) - end -end - -function eigen!(t::AbstractTensorMap; kwargs...) - return ishermitian(t) ? eigh!(t; kwargs...) : eig!(t; kwargs...) -end - -# Singular value decomposition -# ---------------------------- -function tsvd!(t::AbstractTensorMap; trunc=notrunc(), p=nothing, alg=nothing, kwargs...) - InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!) - isnothing(p) || Base.depwarn("p is no longer supported", :tsvd!) - - if alg isa OFA - Base.depwarn(lazy"$alg is deprecated", :tsvd!) - alg = alg === SVD() ? LAPACK_QRIteration() : - alg === SDD() ? LAPACK_DivideAndConquer() : - throw(ArgumentError(lazy"Unknown algorithm $alg")) - end - - if trunc == notrunc() - return svd_compact!(t; alg, kwargs...) - else - return svd_trunc!(t; trunc, alg, kwargs...) - end -end diff --git a/src/tensors/factorizations/interface.jl b/src/tensors/factorizations/interface.jl deleted file mode 100644 index fba7195a4..000000000 --- a/src/tensors/factorizations/interface.jl +++ /dev/null @@ -1,21 +0,0 @@ -@doc """ - isposdef(t::AbstractTensor, [(leftind, rightind)::Index2Tuple]) -> ::Bool - -Test whether a tensor `t` is positive definite as linear map from `rightind` to `leftind`. - -If `leftind` and `rightind` are not specified, the current partition of left and right -indices of `t` is used. In that case, less memory is allocated if one allows the data in -`t` to be destroyed/overwritten, by using `isposdef!(t)`. Note that the permuted tensor on -which `isposdef!` is called should have equal domain and codomain, as otherwise it is -meaningless. -""" isposdef(::AbstractTensorMap), isposdef!(::AbstractTensorMap) - -function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...) - tcopy = copy_oftype(t, factorisation_scalartype(eigen, t)) - return LinearAlgebra.eigvals!(tcopy; kwargs...) -end - -function LinearAlgebra.svdvals(t::AbstractTensorMap) - tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t)) - return LinearAlgebra.svdvals!(tcopy) -end diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 2f595909e..1f6662afc 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -19,7 +19,7 @@ function _select_truncation(::typeof(left_null!), ::AbstractTensorMap, trunc::Na end # Generic Implementations -# ----------------------_ +# ----------------------- for f! in (:qr_compact!, :qr_full!, :lq_compact!, :lq_full!, :eig_full!, :eigh_full!, @@ -78,11 +78,14 @@ end # Singular value decomposition # ---------------------------- -const _T_USVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap,<:AbstractTensorMap} -const _T_USVᴴ_diag = Tuple{<:AbstractTensorMap,<:DiagonalTensorMap,<:AbstractTensorMap} +function check_input(::typeof(svd_full!), t::AbstractTensorMap, USVᴴ, ::AbstractAlgorithm) + U, S, Vᴴ = USVᴴ + + # type checks + @assert U isa AbstractTensorMap + @assert S isa AbstractTensorMap + @assert Vᴴ isa AbstractTensorMap -function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ, - ::AbstractAlgorithm) # scalartype checks @check_scalar U t @check_scalar S t real @@ -98,8 +101,15 @@ function check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)::_T return nothing end -function check_input(::typeof(svd_compact!), t::AbstractTensorMap, (U, S, Vᴴ)::_T_USVᴴ_diag, +function check_input(::typeof(svd_compact!), t::AbstractTensorMap, USVᴴ, ::AbstractAlgorithm) + U, S, Vᴴ = USVᴴ + + # type checks + @assert U isa AbstractTensorMap + @assert S isa DiagonalTensorMap + @assert Vᴴ isa AbstractTensorMap + # scalartype checks @check_scalar U t @check_scalar S t real @@ -122,9 +132,9 @@ function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict, return nothing end -function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, - ::AbstractAlgorithm) +function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t real + @assert D isa DiagonalTensorMap V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) @check_space(D, V_cod ← V_dom) return nothing @@ -166,13 +176,16 @@ end # Eigenvalue decomposition # ------------------------ -const _T_DV = Tuple{<:DiagonalTensorMap,<:AbstractTensorMap} - -function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV, - ::AbstractAlgorithm) +function check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + D, V = DV + + # type checks + @assert D isa DiagonalTensorMap + @assert V isa AbstractTensorMap + # scalartype checks @check_scalar D t real @check_scalar V t @@ -185,11 +198,16 @@ function check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)::_T_DV, return nothing end -function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV, - ::AbstractAlgorithm) +function check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + D, V = DV + + # type checks + @assert D isa DiagonalTensorMap + @assert V isa AbstractTensorMap + # scalartype checks @check_scalar D t complex @check_scalar V t complex @@ -202,11 +220,16 @@ function check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)::_T_DV, return nothing end -function check_input(::typeof(eig_full!), t::DiagonalTensorMap, (D, V)::_T_DV, - ::AbstractAlgorithm) +function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) + D, V = DV + + # type checks + @assert D isa DiagonalTensorMap + @assert V isa AbstractTensorMap + # scalartype checks @check_scalar D t @check_scalar V t @@ -219,17 +242,17 @@ function check_input(::typeof(eig_full!), t::DiagonalTensorMap, (D, V)::_T_DV, return nothing end -function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, - ::AbstractAlgorithm) +function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t real + @assert D isa DiagonalTensorMap V_D = fuse(domain(t)) @check_space(D, V_D ← V_D) return nothing end -function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D::DiagonalTensorMap, - ::AbstractAlgorithm) +function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t complex + @assert D isa DiagonalTensorMap V_D = fuse(domain(t)) @check_space(D, V_D ← V_D) return nothing @@ -287,10 +310,13 @@ end # QR decomposition # ---------------- -const _T_QR = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} +function check_input(::typeof(qr_full!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) + Q, R = QR + + # type checks + @assert Q isa AbstractTensorMap + @assert R isa AbstractTensorMap -function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR, - ::AbstractAlgorithm) # scalartype checks @check_scalar Q t @check_scalar R t @@ -303,8 +329,13 @@ function check_input(::typeof(qr_full!), t::AbstractTensorMap, (Q, R)::_T_QR, return nothing end -function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR, - ::AbstractAlgorithm) +function check_input(::typeof(qr_compact!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) + Q, R = QR + + # type checks + @assert Q isa AbstractTensorMap + @assert R isa AbstractTensorMap + # scalartype checks @check_scalar Q t @check_scalar R t @@ -317,8 +348,7 @@ function check_input(::typeof(qr_compact!), t::AbstractTensorMap, (Q, R)::_T_QR, return nothing end -function check_input(::typeof(qr_null!), t::AbstractTensorMap, N::AbstractTensorMap, - ::AbstractAlgorithm) +function check_input(::typeof(qr_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -353,10 +383,13 @@ end # LQ decomposition # ---------------- -const _T_LQ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} +function check_input(::typeof(lq_full!), t::AbstractTensorMap, LQ, ::AbstractAlgorithm) + L, Q = LQ + + # type checks + @assert L isa AbstractTensorMap + @assert R isa AbstractTensorMap -function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ, - ::AbstractAlgorithm) # scalartype checks @check_scalar L t @check_scalar Q t @@ -369,8 +402,13 @@ function check_input(::typeof(lq_full!), t::AbstractTensorMap, (L, Q)::_T_LQ, return nothing end -function check_input(::typeof(lq_compact!), t::AbstractTensorMap, (L, Q)::_T_LQ, - ::AbstractAlgorithm) +function check_input(::typeof(lq_compact!), t::AbstractTensorMap, LQ, ::AbstractAlgorithm) + L, Q = LQ + + # type checks + @assert L isa AbstractTensorMap + @assert R isa AbstractTensorMap + # scalartype checks @check_scalar L t @check_scalar Q t @@ -418,14 +456,14 @@ end # Polar decomposition # ------------------- -const _T_WP = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -const _T_PWᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} - -function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP, - ::AbstractAlgorithm) +function check_input(::typeof(left_polar!), t::AbstractTensorMap, WP, ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) + W, P = WP + @assert W isa AbstractTensorMap + @assert P isa AbstractTensorMap + # scalartype checks @check_scalar W t @check_scalar P t @@ -437,11 +475,15 @@ function check_input(::typeof(left_polar!), t::AbstractTensorMap, (W, P)::_T_WP, return nothing end -function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, (W, P)::_T_WP, +function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, WP, ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) + W, P = WP + @assert W isa AbstractTensorMap + @assert P isa AbstractTensorMap + # scalartype checks @check_scalar W t @check_scalar P t @@ -460,11 +502,14 @@ function initialize_output(::typeof(left_polar!), t::AbstractTensorMap, ::Abstra return W, P end -function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, - ::AbstractAlgorithm) +function check_input(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) + P, Wᴴ = PWᴴ + @assert P isa AbstractTensorMap + @assert Wᴴ isa AbstractTensorMap + # scalartype checks @check_scalar P t @check_scalar Wᴴ t @@ -476,11 +521,15 @@ function check_input(::typeof(right_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T return nothing end -function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, (P, Wᴴ)::_T_PWᴴ, +function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, PWᴴ, ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) + P, Wᴴ = PWᴴ + @assert P isa AbstractTensorMap + @assert Wᴴ isa AbstractTensorMap + # scalartype checks @check_scalar P t @check_scalar Wᴴ t @@ -512,11 +561,9 @@ end # Orthogonalization # ----------------- -const _T_VC = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} -const _T_CVᴴ = Tuple{<:AbstractTensorMap,<:AbstractTensorMap} +function check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, ::AbstractAlgorithm) + V, C = VC -function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC, - ::AbstractAlgorithm) # scalartype checks @check_scalar V t isnothing(C) || @check_scalar C t @@ -529,8 +576,9 @@ function check_input(::typeof(left_orth!), t::AbstractTensorMap, (V, C)::_T_VC, return nothing end -function check_input(::typeof(right_orth!), t::AbstractTensorMap, (C, Vᴴ)::_T_CVᴴ, - ::AbstractAlgorithm) +function check_input(::typeof(right_orth!), t::AbstractTensorMap, CVᴴ, ::AbstractAlgorithm) + C, Vᴴ = CVᴴ + # scalartype checks isnothing(C) || @check_scalar C t @check_scalar Vᴴ t diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 91032d56e..022c404d4 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -19,7 +19,9 @@ truncspace(space::ElementarySpace) = TruncationSpace(space) # Truncation # ---------- -function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ)::_T_USVᴴ, strategy::TruncationStrategy) +function truncate!(::typeof(svd_trunc!), + (U, S, Vᴴ)::Tuple{AbstractTensorMap,AbstractTensorMap,AbstractTensorMap}, + strategy::TruncationStrategy) ind = findtruncated_sorted(diagview(S), strategy) V_truncated = spacetype(S)(c => length(I) for (c, I) in ind) @@ -48,8 +50,7 @@ function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ)::_T_USVᴴ, strategy::Trun end function truncate!(::typeof(left_null!), - (U, S)::Tuple{<:AbstractTensorMap, - <:AbstractTensorMap}, + (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, strategy::MatrixAlgebraKit.TruncationStrategy) extended_S = SectorDict(c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) @@ -63,7 +64,9 @@ function truncate!(::typeof(left_null!), return Ũ end -function truncate!(::typeof(eigh_trunc!), (D, V)::_T_DV, strategy::TruncationStrategy) +function truncate!(::typeof(eigh_trunc!), + (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, + strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) @@ -83,7 +86,8 @@ function truncate!(::typeof(eigh_trunc!), (D, V)::_T_DV, strategy::TruncationStr return D̃, Ṽ end -function truncate!(::typeof(eig_trunc!), (D, V)::_T_DV, strategy::TruncationStrategy) +function truncate!(::typeof(eig_trunc!), (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, + strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) From a4f6d694f1eb2c93c888fdc3349621276cb29787 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 18:34:48 -0400 Subject: [PATCH 093/126] rework tests and more fixes and cleanup --- src/TensorKit.jl | 20 +- src/auxiliary/deprecate.jl | 25 - src/tensors/factorizations/deprecations.jl | 93 +++- src/tensors/factorizations/factorizations.jl | 10 + .../factorizations/matrixalgebrakit.jl | 18 +- src/tensors/factorizations/truncation.jl | 61 +-- src/tensors/factorizations/utility.jl | 2 +- src/tensors/linalg.jl | 9 +- test/factorizations.jl | 463 ++++++++++-------- 9 files changed, 382 insertions(+), 319 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 6c4f86b04..88467fa4b 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -70,28 +70,30 @@ export inner, dot, norm, normalize, normalize!, tr # factorizations export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby! -export leftorth, rightorth, leftnull, rightnull, - leftorth!, rightorth!, leftnull!, rightnull!, +export left_orth, right_orth, left_null, right_null, + left_orth!, right_orth!, left_null!, right_null!, left_polar, left_polar!, right_polar, right_polar!, qr_full, qr_compact, qr_null, lq_full, lq_compact, lq_null, qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!, - tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!, - eigh_full!, eigh_full, eig_full!, eig_full, eigh_vals!, eigh_vals, - eig_vals!, eig_vals, + svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc, + exp, exp!, + eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eig_full!, eig_full, eig_trunc!, + eig_trunc, + eigh_vals!, eigh_vals, eig_vals!, eig_vals, isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond +# deprecate: +export eig, eig!, eigh, eigh!, eigen, eigen!, tsvd, tsvd!, leftorth, leftorth!, rightorth, + rightorth!, leftnull, leftnull!, rightnull, rightnull! export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! export catdomain, catcodomain, absorb, absorb! -export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQpos, - SVD, SDD, Polar - # tensor operations export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor export scalar, add!, contract! # truncation schemes -export notrunc, truncerr, truncdim, truncspace, truncbelow +export notrunc, truncerr, truncrank, truncspace, trunctol # cache management export empty_globalcaches! diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index b235cbd7c..98f44fba7 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -1,30 +1,5 @@ import Base: transpose -#! format: off -# Base.@deprecate(permute(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false), -# permute(t, (p1, p2); copy=copy)) -# Base.@deprecate(transpose(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple; copy::Bool=false), -# transpose(t, (p1, p2); copy=copy)) -# Base.@deprecate(braid(t::AbstractTensorMap, p1::IndexTuple, p2::IndexTuple, levels; copy::Bool=false), -# braid(t, (p1, p2), levels; copy=copy)) - -# Base.@deprecate(tsvd(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# tsvd(t, (p₁, p₂); kwargs...)) -# Base.@deprecate(leftorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# leftorth(t, (p₁, p₂); kwargs...)) -# Base.@deprecate(rightorth(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# rightorth(t, (p₁, p₂); kwargs...)) -# Base.@deprecate(leftnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# leftnull(t, (p₁, p₂); kwargs...)) -# Base.@deprecate(rightnull(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# rightnull(t, (p₁, p₂); kwargs...)) -# Base.@deprecate(LinearAlgebra.eigen(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# LinearAlgebra.eigen(t, (p₁, p₂); kwargs...), false) -# Base.@deprecate(eig(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# eig(t, (p₁, p₂); kwargs...)) -# Base.@deprecate(eigh(t::AbstractTensorMap, p₁::IndexTuple, p₂::IndexTuple; kwargs...), -# eigh(t, (p₁, p₂); kwargs...)) - for f in (:rand, :randn, :zeros, :ones) @eval begin Base.@deprecate TensorMap(::typeof($f), T::Type, P::HomSpace) $f(T, P) diff --git a/src/tensors/factorizations/deprecations.jl b/src/tensors/factorizations/deprecations.jl index c610602e9..a52383225 100644 --- a/src/tensors/factorizations/deprecations.jl +++ b/src/tensors/factorizations/deprecations.jl @@ -23,58 +23,97 @@ const TruncationScheme = TruncationStrategy # factorizations # -------------- +_kindof(::LAPACK_HouseholderQR) = :qr +_kindof(::LAPACK_HouseholderLQ) = :lq +_kindof(::LAPACK_SVDAlgorithm) = :svd +_kindof(::PolarViaSVD) = :polar + +_drop_alg(; alg=nothing, kwargs...) = kwargs +_drop_p(; p=nothing, kwargs...) = kwargs + # orthogonalization -@deprecate(leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - leftorth!(permutedcopy_oftype(t, factorization_scalartype(leftorth, t), p); - kwargs...)) -@deprecate(rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - rightorth!(permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p); - kwargs...)) -function leftorth(t::AbstractTensorMap; kwargs...) +export leftorth, leftorth!, rightorth, rightorth! +function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) - return left_orth(t; kwargs...) + return leftorth!(permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p); + kwargs...) end -function leftorth!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`leftorth!` is no longer supported, use `left_orth!` instead", :leftorth!) - return left_orth!(t; kwargs...) +function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) + return rightorth!(permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p); + kwargs...) +end +function leftorth(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) + return leftorth!(copy_oftype(t, factorisation_scalartype(leftorth, t)); kwargs...) end function rightorth(t::AbstractTensorMap; kwargs...) Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) - return right_orth(t; kwargs...) + return rightorth!(copy_oftype(t, factorisation_scalartype(rightorth, t)); kwargs...) +end +function leftorth!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftorth!` is no longer supported, use `left_orth!` instead", :leftorth!) + haskey(kwargs, :alg) || return left_orth!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return left_orth!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :qr && return left_orth!(t; kind, alg_qr=alg, _drop_alg(; kwargs...)...) + kind === :polar && return left_orth!(t; kind, alg_polar=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid leftorth kind")) end function rightorth!(t::AbstractTensorMap; kwargs...) Base.depwarn("`rightorth!` is no longer supported, use `right_orth!` instead", :rightorth!) - return right_orth!(t; kwargs...) + haskey(kwargs, :alg) || return right_orth!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return right_orth!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :lq && return right_orth!(t; kind, alg_lq=alg, _drop_alg(; kwargs...)...) + kind === :polar && return right_orth!(t; kind, alg_polar=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid rightorth kind")) end # nullspaces -@deprecate(leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - leftnull!(permutedcopy_oftype(t, factorization_scalartype(leftnull, t), p); - kwargs...)) -@deprecate(rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - rightnull!(permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p); - kwargs...)) +export leftnull, leftnull!, rightnull, rightnull! function leftnull(t::AbstractTensorMap; kwargs...) Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) - return left_null(t; kwargs...) + return leftnull!(copy_oftype(t, factorisation_scalartype(leftnull, t)); kwargs...) end -function leftnull!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`left_null!` is no longer supported, use `left_null!` instead", - :leftnull!) - return left_null!(t; kwargs...) +function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) + return leftnull!(permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p); kwargs...) end function rightnull(t::AbstractTensorMap; kwargs...) Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) - return right_null(t; kwargs...) + return rightnull!(copy_oftype(t, factorisation_scalartype(rightnull, t)); kwargs...) +end +function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) + return rightnull!(permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p); kwargs...) +end +function leftnull!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`left_null!` is no longer supported, use `left_null!` instead", + :leftnull!) + haskey(kwargs, :alg) || return left_null!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return left_null!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :qr && return left_null!(t; kind, alg_qr=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid leftnull kind")) end function rightnull!(t::AbstractTensorMap; kwargs...) Base.depwarn("`rightnull!` is no longer supported, use `right_null!` instead", :rightnull!) - return right_null!(t; kwargs...) + haskey(kwargs, :alg) || return right_null!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return right_null!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :lq && return right_null!(t; kind, alg_lq=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid rightnull kind")) end # eigen values +export eig, eig!, eigh, eigh!, eigen, eigen! @deprecate(eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...), eig!(permutedcopy_oftype(t, factorisation_scalartype(eig, t), p); kwargs...)) @deprecate(eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...), @@ -103,7 +142,7 @@ function eigh!(t::AbstractTensorMap; kwargs...) end # singular values -_drop_p(; p=nothing, kwargs...) = kwargs +export tsvd, tsvd! @deprecate(tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...), tsvd!(permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p); kwargs...)) function tsvd(t::AbstractTensorMap; kwargs...) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index a0efe8066..7718a44b9 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -65,6 +65,13 @@ const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} # LinearAlgebra overloads #------------------------------# +function eigen(t::AbstractTensorMap; kwargs...) + return ishermitian(t) ? eigh_full(t; kwargs...) : eig_full(t; kwargs...) +end +function eigen!(t::AbstractTensorMap; kwargs...) + return ishermitian(t) ? eigh_full!(t; kwargs...) : eig_full!(t; kwargs...) +end + function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...) tcopy = copy_oftype(t, factorisation_scalartype(eigen, t)) return LinearAlgebra.eigvals!(tcopy; kwargs...) @@ -89,6 +96,9 @@ function LinearAlgebra.ishermitian(t::AbstractTensorMap) return true end +function LinearAlgebra.isposdef(t::AbstractTensorMap) + return isposdef!(copy_oftype(t, factorisation_scalartype(isposdef, t))) +end function LinearAlgebra.isposdef!(t::AbstractTensorMap) domain(t) == codomain(t) || throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 1f6662afc..7704dd5f8 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -1,13 +1,17 @@ # Algorithm selection # ------------------- -for f! in - [:svd_compact!, :svd_full!, :svd_trunc!, :svd_vals!, :qr_compact!, :qr_full!, :qr_null!, - :lq_compact!, :lq_full!, :lq_null!, :eig_full!, :eig_trunc!, :eig_vals!, :eigh_full!, - :eigh_trunc!, :eigh_vals!, :left_polar!, :right_polar!] +for f in + [:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, + :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, + :eigh_trunc, :eigh_vals, :left_polar, :right_polar] + f! = Symbol(f, :!) @eval function default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T<:AbstractTensorMap} return default_algorithm($f!, blocktype(T); kwargs...) end + @eval function copy_input(::typeof($f), t::AbstractTensorMap) + return copy_oftype(t, factorisation_scalartype($f, t)) + end end function _select_truncation(f, ::AbstractTensorMap, @@ -388,7 +392,7 @@ function check_input(::typeof(lq_full!), t::AbstractTensorMap, LQ, ::AbstractAlg # type checks @assert L isa AbstractTensorMap - @assert R isa AbstractTensorMap + @assert Q isa AbstractTensorMap # scalartype checks @check_scalar L t @@ -407,7 +411,7 @@ function check_input(::typeof(lq_compact!), t::AbstractTensorMap, LQ, ::Abstract # type checks @assert L isa AbstractTensorMap - @assert R isa AbstractTensorMap + @assert Q isa AbstractTensorMap # scalartype checks @check_scalar L t @@ -546,7 +550,7 @@ function initialize_output(::typeof(right_polar!), t::AbstractTensorMap, ::AbstractAlgorithm) P = similar(t, codomain(t) ← codomain(t)) Wᴴ = similar(t, space(t)) - return Wᴴ, P + return P, Wᴴ end # Needed to get algorithm selection to behave diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 022c404d4..708e5c252 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -64,48 +64,29 @@ function truncate!(::typeof(left_null!), return Ũ end -function truncate!(::typeof(eigh_trunc!), - (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, - strategy::TruncationStrategy) - ind = findtruncated(diagview(D), strategy) - V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) - - D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) - for (c, b) in blocks(D̃) - I = get(ind, c, nothing) - @assert !isnothing(I) - copy!(b.diag, @view(block(D, c).diag[I])) - end - - Ṽ = similar(V, V_truncated ← domain(V)) - for (c, b) in blocks(Ṽ) - I = get(ind, c, nothing) - @assert !isnothing(I) - copy!(b, @view(block(V, c)[I, :])) - end - - return D̃, Ṽ -end -function truncate!(::typeof(eig_trunc!), (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, - strategy::TruncationStrategy) - ind = findtruncated(diagview(D), strategy) - V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) +for f! in (:eig_trunc!, :eigh_trunc!) + @eval function truncate!(::typeof($f!), + (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, + strategy::TruncationStrategy) + ind = findtruncated(diagview(D), strategy) + V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) + + D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) + for (c, b) in blocks(D̃) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b.diag, @view(block(D, c).diag[I])) + end - D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) - for (c, b) in blocks(D̃) - I = get(ind, c, nothing) - @assert !isnothing(I) - copy!(b.diag, @view(block(D, c).diag[I])) - end + Ṽ = similar(V, codomain(V) ← V_truncated) + for (c, b) in blocks(Ṽ) + I = get(ind, c, nothing) + @assert !isnothing(I) + copy!(b, @view(block(V, c)[:, I])) + end - Ṽ = similar(V, V_truncated ← domain(V)) - for (c, b) in blocks(Ṽ) - I = get(ind, c, nothing) - @assert !isnothing(I) - copy!(b, @view(block(V, c)[I, :])) + return D̃, Ṽ end - - return D̃, Ṽ end # Find truncation @@ -201,7 +182,7 @@ function findtruncated(Sd::SectorDict, strategy::TruncationKeepSorted) truncdim[cmin] -= 1 totaldim -= dim(cmin) if totaldim < strategy.howmany - truncdim[cmin] += 1 + # truncdim[cmin] += 1 break end if truncdim[cmin] == 0 diff --git a/src/tensors/factorizations/utility.jl b/src/tensors/factorizations/utility.jl index e23fb8b73..874b944d3 100644 --- a/src/tensors/factorizations/utility.jl +++ b/src/tensors/factorizations/utility.jl @@ -16,7 +16,7 @@ function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2T return permute!(similar(t, T, permute(space(t), p)), t, p) end function copy_oftype(t::AbstractTensorMap, T::Type{<:Number}) - return copy!(similar(t, T), t) + return copy!(similar(t, T, space(t)), t) end function _reverse!(t::AbstractTensorMap; dims=:) diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 5afa3e031..c569cc58f 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -287,12 +287,13 @@ end _default_rtol(t) = eps(real(float(scalartype(t)))) * min(dim(domain(t)), dim(codomain(t))) -function LinearAlgebra.rank(t::AbstractTensorMap; atol::Real=0, - rtol::Real=atol > 0 ? 0 : _default_rtol(t)) - dim(t) == 0 && return 0 +function LinearAlgebra.rank(t::AbstractTensorMap; + atol::Real=0, rtol::Real=atol > 0 ? 0 : _default_rtol(t)) + init = dim(one(sectortype(t))) * 0 + dim(t) == 0 && return init S = LinearAlgebra.svdvals(t) tol = max(atol, rtol * maximum(first, values(S))) - return sum(cs -> dim(cs[1]) * count(>(tol), cs[2]), S) + return sum(((c, b),) -> dim(c) * count(>(tol), b), S; init) end function LinearAlgebra.cond(t::AbstractTensorMap, p::Real=2) diff --git a/test/factorizations.jl b/test/factorizations.jl index cf8521af6..d56bd8329 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -15,221 +15,272 @@ catch (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂) end +eltypes = (Float32, ComplexF64) for V in spacelist I = sectortype(first(V)) Istr = TensorKit.type_repr(I) println("---------------------------------------") println("Tensors with symmetry: $Istr") println("---------------------------------------") - @timedtestset "Tensors with symmetry: $Istr" verbose = true begin + @timedtestset "Factorizations with symmetry: $Istr" verbose = true begin V1, V2, V3, V4, V5 = V - @timedtestset "Factorization" begin - W = V1 ⊗ V2 - @testset for T in (Float32, ComplexF64) - # Test both a normal tensor and an adjoint one. - ts = (rand(T, W, W'), rand(T, W, W')', rand(T, V1, W'), rand(T, V1, W')') - @testset for t in ts - @testset "qr_full" begin - Q, R = @constinferred qr_full(t) - @test isisometry(Q) - @test Q * R ≈ t - end - @testset "qr_compact" begin - Q, R = @constinferred qr_compact(t) - @test isisometry(Q) - @test Q * R ≈ t - end - @testset "qr_null" begin - N = @constinferred qr_null(t) - @test isisometry(N) - @test norm(N' * t) < 100 * eps(norm(t)) - end - @testset "lq_full" begin - L, Q = @constinferred lq_full(t) - @test isisometry(Q; side=:right) - @test L * Q ≈ t - end - @testset "lq_compact" begin - L, Q = @constinferred lq_compact(t) - @test isisometry(Q; side=:right) - @test L * Q ≈ t - end - @testset "lq_null" begin - Nᴴ = @constinferred lq_null(t) - @test isisometry(Nᴴ; side=:right) - @test norm(t * Nᴴ') < 100 * eps(norm(t)) - end - @testset "leftorth with $alg" for alg in - (TensorKit.LAPACK_HouseholderQR(), - TensorKit.LAPACK_HouseholderQR(; - positive=true), - #TensorKit.QL(), - #TensorKit.QLpos(), - TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), - TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - (codomain(t) ≾ domain(t)) && alg isa TensorKit.PolarViaSVD && - continue - Q, R = @constinferred leftorth(t; alg=alg) - @test isisometry(Q) - @test Q * R ≈ t - end - @testset "leftnull with $alg" for alg in - (TensorKit.LAPACK_HouseholderQR(), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - N = @constinferred leftnull(t; alg=alg) - @test isisometry(N) - @test norm(N' * t) < 100 * eps(norm(t)) - end - @testset "rightorth with $alg" for alg in - (TensorKit.LAPACK_HouseholderLQ(), - TensorKit.LAPACK_HouseholderLQ(; - positive=true), - TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), - TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - (domain(t) ≾ codomain(t)) && alg isa TensorKit.PolarViaSVD && - continue - L, Q = @constinferred rightorth(t; alg=alg) - @test isisometry(Q; side=:right) - @test L * Q ≈ t - end - @testset "rightnull with $alg" for alg in - (TensorKit.LAPACK_HouseholderLQ(), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - M = @constinferred rightnull(t; alg=alg) - @test isisometry(M; side=:right) - @test norm(t * M') < 100 * eps(norm(t)) - end - @testset "tsvd with $alg" for alg in (TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - U, S, V = @constinferred tsvd(t; alg=alg) - @test isisometry(U) - @test isisometry(V; side=:right) - @test U * S * V ≈ t - - s = LinearAlgebra.svdvals(t) - s′ = LinearAlgebra.diag(S) - for (c, b) in s - @test b ≈ s′[c] - end - s = LinearAlgebra.svdvals(t') - s′ = LinearAlgebra.diag(S') - for (c, b) in s - @test b ≈ s′[c] - end - end - @testset "cond and rank" begin - d1 = dim(codomain(t)) - d2 = dim(domain(t)) - @test rank(t) == min(d1, d2) - M = leftnull(t) - @test rank(M) + rank(t) == d1 - t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2) - @test cond(t3) ≈ one(real(T)) - @test rank(t3) == dim(V1 ⊗ V2) - t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2) - t4 = (t4 + t4') / 2 - vals = LinearAlgebra.eigvals(t4) - λmax = maximum(s -> maximum(abs, s), values(vals)) - λmin = minimum(s -> minimum(abs, s), values(vals)) - @test cond(t4) ≈ λmax / λmin - vals = LinearAlgebra.eigvals(t4') - λmax = maximum(s -> maximum(abs, s), values(vals)) - λmin = minimum(s -> minimum(abs, s), values(vals)) - @test cond(t4') ≈ λmax / λmin - end - end - @testset "empty tensor" begin - t = randn(T, V1 ⊗ V2, zero(V1)) - @testset "leftorth with $alg" for alg in - (TensorKit.LAPACK_HouseholderQR(), - TensorKit.LAPACK_HouseholderQR(; - positive=true), - #TensorKit.QL(), TensorKit.QLpos(), - TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), - TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - Q, R = @constinferred leftorth(t; alg=alg) - @test Q == t - @test dim(Q) == dim(R) == 0 - end - @testset "leftnull with $alg" for alg in - (TensorKit.LAPACK_HouseholderQR(), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - N = @constinferred leftnull(t; alg=alg) - @test isunitary(N) - end - @testset "rightorth with $alg" for alg in - (TensorKit.LAPACK_HouseholderLQ(), - TensorKit.LAPACK_HouseholderLQ(; - positive=true), - TensorKit.PolarViaSVD(TensorKit.LAPACK_QRIteration()), - TensorKit.PolarViaSVD(TensorKit.LAPACK_DivideAndConquer()), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - L, Q = @constinferred rightorth(copy(t'); alg=alg) - @test Q == t' - @test dim(Q) == dim(L) == 0 - end - @testset "rightnull with $alg" for alg in - (TensorKit.LAPACK_HouseholderLQ(), - TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - M = @constinferred rightnull(copy(t'); alg=alg) - @test isunitary(M) - end - @testset "tsvd with $alg" for alg in (TensorKit.LAPACK_QRIteration(), - TensorKit.LAPACK_DivideAndConquer()) - U, S, V = @constinferred tsvd(t; alg=alg) - @test U == t - @test dim(U) == dim(S) == dim(V) - end - @testset "cond and rank" begin - @test rank(t) == 0 - W2 = zero(V1) * zero(V2) - t2 = rand(W2, W2) - @test rank(t2) == 0 - @test cond(t2) == 0.0 - end + W = V1 ⊗ V2 + + @testset "QR decomposition" begin + for T in eltypes, + t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)') + + Q, R = @constinferred qr_full(t) + @test Q * R ≈ t + @test isunitary(Q) + + Q, R = @constinferred qr_compact(t) + @test Q * R ≈ t + @test isisometry(Q) + + Q, R = @constinferred left_orth(t; kind=:qr) + @test Q * R ≈ t + @test isisometry(Q) + + N = @constinferred qr_null(t) + @test isisometry(N) + @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) + + N = @constinferred left_null(t; kind=:qr) + @test isisometry(N) + @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) + end + + # empty tensor + for T in eltypes + t = rand(T, V1 ⊗ V2, zero(V1)) + + Q, R = @constinferred qr_full(t) + @test Q * R ≈ t + @test isunitary(Q) + @test dim(R) == dim(t) == 0 + + Q, R = @constinferred qr_compact(t) + @test Q * R ≈ t + @test isisometry(Q) + @test dim(Q) == dim(R) == dim(t) + + Q, R = @constinferred left_orth(t; kind=:qr) + @test Q * R ≈ t + @test isisometry(Q) + @test dim(Q) == dim(R) == dim(t) + + N = @constinferred qr_null(t) + @test isunitary(N) + @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) + end + end + + @testset "LQ decomposition" begin + for T in eltypes, + t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)') + + L, Q = @constinferred lq_full(t) + @test L * Q ≈ t + @test isunitary(Q) + + L, Q = @constinferred lq_compact(t) + @test L * Q ≈ t + @test isisometry(Q; side=:right) + + L, Q = @constinferred right_orth(t; kind=:lq) + @test L * Q ≈ t + @test isisometry(Q; side=:right) + + Nᴴ = @constinferred lq_null(t) + @test isisometry(Nᴴ; side=:right) + @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) + end + + for T in eltypes + # empty tensor + t = rand(T, zero(V1), V1 ⊗ V2) + + L, Q = @constinferred lq_full(t) + @test L * Q ≈ t + @test isunitary(Q) + @test dim(L) == dim(t) == 0 + + L, Q = @constinferred lq_compact(t) + @test L * Q ≈ t + @test isisometry(Q; side=:right) + @test dim(Q) == dim(L) == dim(t) + + L, Q = @constinferred right_orth(t; kind=:lq) + @test L * Q ≈ t + @test isisometry(Q; side=:right) + @test dim(Q) == dim(L) == dim(t) + + Nᴴ = @constinferred lq_null(t) + @test isunitary(Nᴴ) + @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) + end + end + + @testset "Polar decomposition" begin + for T in eltypes, + t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)') + + @assert domain(t) ≾ codomain(t) + w, p = @constinferred left_polar(t) + @test w * p ≈ t + @test isisometry(w) + @test isposdef(p) + + w, p = @constinferred left_orth(t; kind=:polar) + @test w * p ≈ t + @test isisometry(w) + end + + for T in eltypes, + t in (rand(T, W, W), rand(T, W, W)', rand(T, V1, W), rand(T, W, V1)') + + @assert codomain(t) ≾ domain(t) + p, wᴴ = @constinferred right_polar(t) + @test p * wᴴ ≈ t + @test isisometry(wᴴ; side=:right) + @test isposdef(p) + + p, wᴴ = @constinferred right_orth(t; kind=:polar) + @test p * wᴴ ≈ t + @test isisometry(wᴴ; side=:right) + end + end + + @testset "SVD" begin + for T in eltypes, + t in (rand(T, W, W), rand(T, W, W)', + rand(T, W, V1), rand(T, V1, W), + rand(T, W, V1)', rand(T, V1, W)') + + u, s, vᴴ = @constinferred svd_full(t) + @test u * s * vᴴ ≈ t + @test isunitary(u) + @test isunitary(vᴴ) + + u, s, vᴴ = @constinferred svd_compact(t) + @test u * s * vᴴ ≈ t + @test isisometry(u) + @test isposdef(s) + @test isisometry(vᴴ; side=:right) + + s′ = LinearAlgebra.diag(s) + for (c, b) in LinearAlgebra.svdvals(t) + @test b ≈ s′[c] end - @testset "eig and isposdef" begin - t = rand(T, V1, V1) - D, V = eigen(t) - @test t * V ≈ V * D - - d = LinearAlgebra.eigvals(t; sortby=nothing) - d′ = LinearAlgebra.diag(D) - for (c, b) in d - @test b ≈ d′[c] - end - - # Somehow moving these test before the previous one gives rise to errors - # with T=Float32 on x86 platforms. Is this an OpenBLAS issue? - VdV = V' * V - VdV = (VdV + VdV') / 2 - @test isposdef(VdV) - - @test !isposdef(t) # unlikely for non-hermitian map - t2 = (t + t') - D, V = eigen(t2) - @test isisometry(V) - D̃, Ṽ = @constinferred eigh(t2) - @test D ≈ D̃ - @test V ≈ Ṽ - λ = minimum(minimum(real(LinearAlgebra.diag(b))) - for (c, b) in blocks(D)) - @test cond(Ṽ) ≈ one(real(T)) - @test isposdef(t2) == isposdef(λ) - @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) - @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) + + v, c = @constinferred left_orth(t; kind=:svd) + @test v * c ≈ t + @test isisometry(v) + + N = @constinferred left_null(t; kind=:svd) + @test isisometry(N) + @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) + + Nᴴ = @constinferred right_null(t; kind=:svd) + @test isisometry(Nᴴ; side=:right) + @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) + end + + # empty tensor + for T in eltypes, t in (rand(T, W, zero(V1)), rand(T, zero(V1), W)) + U, S, Vᴴ = @constinferred svd_full(t) + @test U * S * Vᴴ ≈ t + @test isunitary(U) + @test isunitary(Vᴴ) + + U, S, Vᴴ = @constinferred svd_compact(t) + @test U * S * Vᴴ ≈ t + @test dim(U) == dim(S) == dim(Vᴴ) == dim(t) == 0 + end + end + + @testset "Eigenvalue decomposition" begin + for T in eltypes, t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)') + d, v = @constinferred eig_full(t) + @test t * v ≈ v * d + + d′ = LinearAlgebra.diag(d) + for (c, b) in LinearAlgebra.eigvals(t) + @test sort(b; by=abs) ≈ sort(d′[c]; by=abs) end + + vdv = v' * v + vdv = (vdv + vdv') / 2 + @test @constinferred isposdef(vdv) + @test !isposdef(t) # unlikely for non-hermitian map + + d, v = @constinferred eig_trunc(t; trunc=truncrank(dim(domain(t)) ÷ 2)) + @test t * v ≈ v * d + @test dim(domain(d)) ≤ dim(domain(t)) ÷ 2 + + + t2 = (t + t') + D, V = eigen(t2) + @test isisometry(V) + D̃, Ṽ = @constinferred eigh(t2) + @test D ≈ D̃ + @test V ≈ Ṽ + λ = minimum(minimum(real(LinearAlgebra.diag(b))) + for (c, b) in blocks(D)) + @test cond(Ṽ) ≈ one(real(T)) + @test isposdef(t2) == isposdef(λ) + @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) + @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) + + add!(t, t') + + d, v = @constinferred eigh_full(t) + @test t * v ≈ v * d + @test isunitary(v) + + λ = minimum(minimum(real(LinearAlgebra.diag(b))) for (c, b) in blocks(d)) + @test cond(v) ≈ one(real(T)) + @test isposdef(t) == isposdef(λ) + @test isposdef(t - λ * one(t) + 0.1 * one(t)) + @test !isposdef(t - λ * one(t) - 0.1 * one(t)) + + d, v = @constinferred eigh_trunc(t; trunc=truncrank(dim(domain(t)) ÷ 2)) + @test t * v ≈ v * d + @test dim(domain(d)) ≤ dim(domain(t)) ÷ 2 + end + end + + @testset "Condition number and rank" begin + for T in eltypes, + t in (rand(T, W, W), rand(T, W, W)', + rand(T, W, V1), rand(T, V1, W), + rand(T, W, V1)', rand(T, V1, W)') + + d1, d2 = dim(codomain(t)), dim(domain(t)) + @test rank(t) == min(d1, d2) + M = left_null(t) + @test @constinferred(rank(M)) + rank(t) == d1 + Mᴴ = right_null(t) + @test rank(Mᴴ) + rank(t) == d2 + end + for T in eltypes + u = unitary(T, V1 ⊗ V2, V1 ⊗ V2) + @test @constinferred(cond(u)) ≈ one(real(T)) + @test @constinferred(rank(u)) == dim(V1 ⊗ V2) + + t = rand(T, zero(V1), W) + @test rank(t) == 0 + t2 = rand(T, zero(V1) * zero(V2), zero(V1) * zero(V2)) + @test rank(t2) == 0 + @test cond(t2) == 0.0 + end + for T in eltypes, t in (rand(T, W, W), rand(T, W, W)') + add!(t, t') + vals = @constinferred LinearAlgebra.eigvals(t) + λmax = maximum(s -> maximum(abs, s), values(vals)) + λmin = minimum(s -> minimum(abs, s), values(vals)) + @test cond(t) ≈ λmax / λmin end end end From 20b15a79f4f28ddf0a81dbbdd7d5740a5cc7c4eb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 18:47:19 -0400 Subject: [PATCH 094/126] move deprecations to single location --- src/auxiliary/deprecate.jl | 156 +++++++++++++++++ src/tensors/factorizations/deprecations.jl | 167 ------------------- src/tensors/factorizations/factorizations.jl | 1 - 3 files changed, 156 insertions(+), 168 deletions(-) delete mode 100644 src/tensors/factorizations/deprecations.jl diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index 98f44fba7..7661c2cf7 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -33,4 +33,160 @@ Base.@deprecate EuclideanProduct() EuclideanInnerProduct() Base.@deprecate insertunit(P::ProductSpace, args...; kwargs...) insertleftunit(args...; kwargs...) +# Factorization structs +@deprecate QR() MatrixAlgebraKit.LAPACK_HouseholderQR() +@deprecate QRpos() MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true) + +@deprecate QL() MatrixAlgebraKit.LAPACK_HouseholderQL() +@deprecate QLpos() MatrixAlgebraKit.LAPACK_HouseholderQL(; positive=true) + +@deprecate LQ() MatrixAlgebraKit.LAPACK_HouseholderLQ() +@deprecate LQpos() MatrixAlgebraKit.LAPACK_HouseholderLQ(; positive=true) + +@deprecate RQ() MatrixAlgebraKit.LAPACK_HouseholderRQ() +@deprecate RQpos() MatrixAlgebraKit.LAPACK_HouseholderRQ(; positive=true) + +@deprecate SDD() MatrixAlgebraKit.LAPACK_DivideAndConquer() +@deprecate SVD() MatrixAlgebraKit.LAPACK_QRIteration() + +@deprecate Polar() MatrixAlgebraKit.PolarViaSVD(MatrixAlgebraKit.LAPACK_DivideAndConquer()) + +# truncations +const TruncationScheme = MatrixAlgebraKit.TruncationStrategy +@deprecate truncdim(d::Int) truncrank(d) +@deprecate truncbelow(ϵ::Real) trunctol(ϵ) + +# factorizations +# -------------- +_kindof(::MatrixAlgebraKit.LAPACK_HouseholderQR) = :qr +_kindof(::MatrixAlgebraKit.LAPACK_HouseholderLQ) = :lq +_kindof(::MatrixAlgebraKit.LAPACK_SVDAlgorithm) = :svd +_kindof(::MatrixAlgebraKit.PolarViaSVD) = :polar + +_drop_alg(; alg=nothing, kwargs...) = kwargs +_drop_p(; p=nothing, kwargs...) = kwargs + +# orthogonalization +export leftorth, leftorth!, rightorth, rightorth! +function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) + return leftorth!(permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p); kwargs...) +end +function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) + return rightorth!(permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p); kwargs...) +end +function leftorth(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) + return leftorth!(copy_oftype(t, factorisation_scalartype(leftorth, t)); kwargs...) +end +function rightorth(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) + return rightorth!(copy_oftype(t, factorisation_scalartype(rightorth, t)); kwargs...) +end +function leftorth!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftorth!` is no longer supported, use `left_orth!` instead", :leftorth!) + haskey(kwargs, :alg) || return left_orth!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return left_orth!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :qr && return left_orth!(t; kind, alg_qr=alg, _drop_alg(; kwargs...)...) + kind === :polar && return left_orth!(t; kind, alg_polar=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid leftorth kind")) +end +function rightorth!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightorth!` is no longer supported, use `right_orth!` instead", :rightorth!) + haskey(kwargs, :alg) || return right_orth!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return right_orth!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :lq && return right_orth!(t; kind, alg_lq=alg, _drop_alg(; kwargs...)...) + kind === :polar && return right_orth!(t; kind, alg_polar=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid rightorth kind")) +end + +# nullspaces +export leftnull, leftnull!, rightnull, rightnull! +function leftnull(t::AbstractTensorMap; kwargs...) + Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) + return leftnull!(copy_oftype(t, factorisation_scalartype(leftnull, t)); kwargs...) +end +function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) + return leftnull!(permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p); kwargs...) +end +function rightnull(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) + return rightnull!(copy_oftype(t, factorisation_scalartype(rightnull, t)); kwargs...) +end +function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) + Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) + return rightnull!(permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p); kwargs...) +end +function leftnull!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`left_null!` is no longer supported, use `left_null!` instead", :leftnull!) + haskey(kwargs, :alg) || return left_null!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return left_null!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :qr && return left_null!(t; kind, alg_qr=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid leftnull kind")) +end +function rightnull!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`rightnull!` is no longer supported, use `right_null!` instead", :rightnull!) + haskey(kwargs, :alg) || return right_null!(t; kwargs...) + alg = kwargs[:alg] + kind = _kindof(alg) + kind === :svd && return right_null!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) + kind === :lq && return right_null!(t; kind, alg_lq=alg, _drop_alg(; kwargs...)...) + throw(ArgumentError("invalid rightnull kind")) +end + +# eigen values +export eig!, eigh!, eigen, eigen! +@deprecate(eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + eig!(permutedcopy_oftype(t, factorisation_scalartype(eig, t), p); kwargs...)) +@deprecate(eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + eigh!(permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p); kwargs...)) +@deprecate(LinearAlgebra.eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + eigen!(permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p); kwargs...), + false) +function eig(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eig` is no longer supported, use `eig_full` or `eig_trunc` instead", :eig) + return haskey(kwargs, :trunc) ? eig_trunc(t; kwargs...) : eig_full(t; kwargs...) +end +function eig!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eig!` is no longer supported, use `eig_full!` or `eig_trunc!` instead", :eig!) + return haskey(kwargs, :trunc) ? eig_trunc!(t; kwargs...) : eig_full!(t; kwargs...) +end +function eigh(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eigh` is no longer supported, use `eigh_full` or `eigh_trunc` instead", :eigh) + return haskey(kwargs, :trunc) ? eigh_trunc(t; kwargs...) : eigh_full(t; kwargs...) +end +function eigh!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`eigh!` is no longer supported, use `eigh_full!` or `eigh_trunc!` instead", :eigh!) + return haskey(kwargs, :trunc) ? eigh_trunc!(t; kwargs...) : eigh_full!(t; kwargs...) +end + +# singular values +export tsvd, tsvd! +@deprecate(tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...), + tsvd!(permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p); kwargs...)) +function tsvd(t::AbstractTensorMap; kwargs...) + Base.depwarn("`tsvd` is no longer supported, use `svd_compact`, `svd_full` or `svd_trunc` instead", :tsvd) + if haskey(kwargs, :p) + Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", :tsvd) + kwargs = _drop_p(; kwargs...) + end + return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : svd_compact(t; kwargs...) +end +function tsvd!(t::AbstractTensorMap; kwargs...) + Base.depwarn("`tsvd!` is no longer supported, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", :tsvd!) + if haskey(kwargs, :p) + Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", :tsvd!) + kwargs = _drop_p(; kwargs...) + end + return haskey(kwargs, :trunc) ? svd_trunc!(t; kwargs...) : svd_compact!(t; kwargs...) +end + #! format: on diff --git a/src/tensors/factorizations/deprecations.jl b/src/tensors/factorizations/deprecations.jl deleted file mode 100644 index a52383225..000000000 --- a/src/tensors/factorizations/deprecations.jl +++ /dev/null @@ -1,167 +0,0 @@ -# Factorization structs -@deprecate QR() LAPACK_HouseholderQR() -@deprecate QRpos() LAPACK_HouseholderQR(; positive=true) - -@deprecate QL() LAPACK_HouseholderQL() -@deprecate QLpos() LAPACK_HouseholderQL(; positive=true) - -@deprecate LQ() LAPACK_HouseholderLQ() -@deprecate LQpos() LAPACK_HouseholderLQ(; positive=true) - -@deprecate RQ() LAPACK_HouseholderRQ() -@deprecate RQpos() LAPACK_HouseholderRQ(; positive=true) - -@deprecate SDD() LAPACK_DivideAndConquer() -@deprecate SVD() LAPACK_QRIteration() - -@deprecate Polar() PolarViaSVD(LAPACK_DivideAndConquer()) - -# truncations -const TruncationScheme = TruncationStrategy -@deprecate truncdim(d::Int) truncrank(d) -@deprecate truncbelow(ϵ::Real) trunctol(ϵ) - -# factorizations -# -------------- -_kindof(::LAPACK_HouseholderQR) = :qr -_kindof(::LAPACK_HouseholderLQ) = :lq -_kindof(::LAPACK_SVDAlgorithm) = :svd -_kindof(::PolarViaSVD) = :polar - -_drop_alg(; alg=nothing, kwargs...) = kwargs -_drop_p(; p=nothing, kwargs...) = kwargs - -# orthogonalization -export leftorth, leftorth!, rightorth, rightorth! -function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) - return leftorth!(permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p); - kwargs...) -end -function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) - return rightorth!(permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p); - kwargs...) -end -function leftorth(t::AbstractTensorMap; kwargs...) - Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) - return leftorth!(copy_oftype(t, factorisation_scalartype(leftorth, t)); kwargs...) -end -function rightorth(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) - return rightorth!(copy_oftype(t, factorisation_scalartype(rightorth, t)); kwargs...) -end -function leftorth!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`leftorth!` is no longer supported, use `left_orth!` instead", :leftorth!) - haskey(kwargs, :alg) || return left_orth!(t; kwargs...) - alg = kwargs[:alg] - kind = _kindof(alg) - kind === :svd && return left_orth!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) - kind === :qr && return left_orth!(t; kind, alg_qr=alg, _drop_alg(; kwargs...)...) - kind === :polar && return left_orth!(t; kind, alg_polar=alg, _drop_alg(; kwargs...)...) - throw(ArgumentError("invalid leftorth kind")) -end -function rightorth!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightorth!` is no longer supported, use `right_orth!` instead", - :rightorth!) - haskey(kwargs, :alg) || return right_orth!(t; kwargs...) - alg = kwargs[:alg] - kind = _kindof(alg) - kind === :svd && return right_orth!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) - kind === :lq && return right_orth!(t; kind, alg_lq=alg, _drop_alg(; kwargs...)...) - kind === :polar && return right_orth!(t; kind, alg_polar=alg, _drop_alg(; kwargs...)...) - throw(ArgumentError("invalid rightorth kind")) -end - -# nullspaces -export leftnull, leftnull!, rightnull, rightnull! -function leftnull(t::AbstractTensorMap; kwargs...) - Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) - return leftnull!(copy_oftype(t, factorisation_scalartype(leftnull, t)); kwargs...) -end -function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) - return leftnull!(permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p); kwargs...) -end -function rightnull(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) - return rightnull!(copy_oftype(t, factorisation_scalartype(rightnull, t)); kwargs...) -end -function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) - return rightnull!(permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p); kwargs...) -end -function leftnull!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`left_null!` is no longer supported, use `left_null!` instead", - :leftnull!) - haskey(kwargs, :alg) || return left_null!(t; kwargs...) - alg = kwargs[:alg] - kind = _kindof(alg) - kind === :svd && return left_null!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) - kind === :qr && return left_null!(t; kind, alg_qr=alg, _drop_alg(; kwargs...)...) - throw(ArgumentError("invalid leftnull kind")) -end -function rightnull!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightnull!` is no longer supported, use `right_null!` instead", - :rightnull!) - haskey(kwargs, :alg) || return right_null!(t; kwargs...) - alg = kwargs[:alg] - kind = _kindof(alg) - kind === :svd && return right_null!(t; kind, alg_svd=alg, _drop_alg(; kwargs...)...) - kind === :lq && return right_null!(t; kind, alg_lq=alg, _drop_alg(; kwargs...)...) - throw(ArgumentError("invalid rightnull kind")) -end - -# eigen values -export eig, eig!, eigh, eigh!, eigen, eigen! -@deprecate(eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - eig!(permutedcopy_oftype(t, factorisation_scalartype(eig, t), p); kwargs...)) -@deprecate(eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - eigh!(permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p); kwargs...)) -@deprecate(eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - eigen!(permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p); kwargs...)) -function eig(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eig` is no longer supported, use `eig_full` or `eig_trunc` instead", - :eig) - return haskey(kwargs, :trunc) ? eig_trunc(t; kwargs...) : eig_full(t; kwargs...) -end -function eig!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eig!` is no longer supported, use `eig_full!` or `eig_trunc!` instead", - :eig!) - return haskey(kwargs, :trunc) ? eig_trunc!(t; kwargs...) : eig_full!(t; kwargs...) -end -function eigh(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eigh` is no longer supported, use `eigh_full` or `eigh_trunc` instead", - :eigh) - return haskey(kwargs, :trunc) ? eigh_trunc(t; kwargs...) : eigh_full(t; kwargs...) -end -function eigh!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eigh!` is no longer supported, use `eigh_full!` or `eigh_trunc!` instead", - :eigh!) - return haskey(kwargs, :trunc) ? eigh_trunc!(t; kwargs...) : eigh_full!(t; kwargs...) -end - -# singular values -export tsvd, tsvd! -@deprecate(tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...), - tsvd!(permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p); kwargs...)) -function tsvd(t::AbstractTensorMap; kwargs...) - Base.depwarn("`tsvd` is no longer supported, use `svd_compact`, `svd_full` or `svd_trunc` instead", - :tsvd) - if haskey(kwargs, :p) - Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", - :tsvd) - kwargs = _drop_p(; kwargs...) - end - return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : svd_compact(t; kwargs...) -end -function tsvd!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`tsvd!` is no longer supported, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", - :tsvd!) - if haskey(kwargs, :p) - Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", - :tsvd!) - kwargs = _drop_p(; kwargs...) - end - return haskey(kwargs, :trunc) ? svd_trunc!(t; kwargs...) : svd_compact!(t; kwargs...) -end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 7718a44b9..ac230ac43 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -44,7 +44,6 @@ import MatrixAlgebraKit: default_algorithm, include("utility.jl") include("matrixalgebrakit.jl") include("truncation.jl") -include("deprecations.jl") include("adjoint.jl") include("diagonal.jl") From 96630e57c65f4a6ff4331f594b33a1227e2c930c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 18:57:00 -0400 Subject: [PATCH 095/126] type stability of rank --- src/tensors/linalg.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index c569cc58f..7ab6c2988 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -289,11 +289,17 @@ _default_rtol(t) = eps(real(float(scalartype(t)))) * min(dim(domain(t)), dim(cod function LinearAlgebra.rank(t::AbstractTensorMap; atol::Real=0, rtol::Real=atol > 0 ? 0 : _default_rtol(t)) - init = dim(one(sectortype(t))) * 0 - dim(t) == 0 && return init + r = dim(one(sectortype(t))) * 0 + dim(t) == 0 && return r S = LinearAlgebra.svdvals(t) tol = max(atol, rtol * maximum(first, values(S))) - return sum(((c, b),) -> dim(c) * count(>(tol), b), S; init) + for (c, b) in S + if !isempty(b) + r += dim(c) * count(>(tol), b) + end + end + return r + # return sum(((c, b),) -> dim(c) * count(>(tol), b), S; init) end function LinearAlgebra.cond(t::AbstractTensorMap, p::Real=2) From 1084853f084b0a09e1161a29b16923e135fa5ed0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 18:57:09 -0400 Subject: [PATCH 096/126] formatter --- src/auxiliary/deprecate.jl | 2 ++ .../factorizations/matrixalgebrakit.jl | 2 +- test/factorizations.jl | 19 +++++++++---------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index 7661c2cf7..163c4933c 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -1,5 +1,7 @@ import Base: transpose +#! format: off + for f in (:rand, :randn, :zeros, :ones) @eval begin Base.@deprecate TensorMap(::typeof($f), T::Type, P::HomSpace) $f(T, P) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 7704dd5f8..8bfc268e0 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -4,7 +4,7 @@ for f in [:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, :left_polar, :right_polar] - f! = Symbol(f, :!) + f! = Symbol(f, :!) @eval function default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T<:AbstractTensorMap} return default_algorithm($f!, blocktype(T); kwargs...) diff --git a/test/factorizations.jl b/test/factorizations.jl index d56bd8329..ad9015430 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -177,11 +177,11 @@ for V in spacelist v, c = @constinferred left_orth(t; kind=:svd) @test v * c ≈ t @test isisometry(v) - + N = @constinferred left_null(t; kind=:svd) @test isisometry(N) @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) - + Nᴴ = @constinferred right_null(t; kind=:svd) @test isisometry(Nᴴ; side=:right) @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) @@ -204,12 +204,12 @@ for V in spacelist for T in eltypes, t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)') d, v = @constinferred eig_full(t) @test t * v ≈ v * d - + d′ = LinearAlgebra.diag(d) for (c, b) in LinearAlgebra.eigvals(t) @test sort(b; by=abs) ≈ sort(d′[c]; by=abs) end - + vdv = v' * v vdv = (vdv + vdv') / 2 @test @constinferred isposdef(vdv) @@ -218,8 +218,7 @@ for V in spacelist d, v = @constinferred eig_trunc(t; trunc=truncrank(dim(domain(t)) ÷ 2)) @test t * v ≈ v * d @test dim(domain(d)) ≤ dim(domain(t)) ÷ 2 - - + t2 = (t + t') D, V = eigen(t2) @test isisometry(V) @@ -232,13 +231,13 @@ for V in spacelist @test isposdef(t2) == isposdef(λ) @test isposdef(t2 - λ * one(t2) + 0.1 * one(t2)) @test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2)) - + add!(t, t') d, v = @constinferred eigh_full(t) @test t * v ≈ v * d @test isunitary(v) - + λ = minimum(minimum(real(LinearAlgebra.diag(b))) for (c, b) in blocks(d)) @test cond(v) ≈ one(real(T)) @test isposdef(t) == isposdef(λ) @@ -250,7 +249,7 @@ for V in spacelist @test dim(domain(d)) ≤ dim(domain(t)) ÷ 2 end end - + @testset "Condition number and rank" begin for T in eltypes, t in (rand(T, W, W), rand(T, W, W)', @@ -268,7 +267,7 @@ for V in spacelist u = unitary(T, V1 ⊗ V2, V1 ⊗ V2) @test @constinferred(cond(u)) ≈ one(real(T)) @test @constinferred(rank(u)) == dim(V1 ⊗ V2) - + t = rand(T, zero(V1), W) @test rank(t) == 0 t2 = rand(T, zero(V1) * zero(V2), zero(V1) * zero(V2)) From 46807e59c8151e2324cbaecab4f3507cae597e52 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 19:07:26 -0400 Subject: [PATCH 097/126] small docs update --- docs/src/lib/tensors.md | 59 ++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index 51b9bcf13..06f5da3d1 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -214,26 +214,49 @@ contract! ## `TensorMap` factorizations -The factorisation methods come in two flavors, namely a non-destructive version where you -can specify an additional permutation of the domain and codomain indices before the -factorisation is performed (provided that `sectorstyle(t)` has a symmetric braiding) as -well as a destructive version The non-destructive methods are given first: - -```@docs -leftorth -rightorth -leftnull -rightnull -tsvd -eigh -eig -eigen +The factorisation methods are powered by [MatrixAlgebraKit.jl](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl) +and all follow the same strategy. The idea is that the `TensorMap` is interpreted as a linear +map based on the current partition of indices between `domain` and `codomain`, and then the +entire range of MatrixAlgebraKit functions can be called. +You can specify an additional permutation of the domain and codomain indices before the +factorisation is performed by making use of [`permute`](@ref) or [`transpose`](@ref), + +```@docs +left_orth +right_orth +left_null +right_null +svd_compact +svd_full +svd_vals +eig_full +eig_vals +eigh_full +eigh_vals isposdef ``` -The corresponding destructive methods have an exclamation mark at the end of their name, -and only accept the `TensorMap` object as well as the method-specific algorithm and keyword -arguments. +Additionally, it is possible to obtain truncated versions of some of these factorizations +through the [`MatrixAlgebraKit.TruncationStrategy`](@ref) objects. + +```@docs +svd_trunc +eig_trunc +eigh_trunc +``` + +The exact truncation strategy can be controlled through one of the following strategies: +```@docs +notrunc +trunctol +truncrank +truncerr +truncspace +``` -TODO: document svd truncation types +It is additionally possible to combine multiple strategies through `&`, e.g. + +```julia +combined_truncation = trunctol(; atol=1e-2) & truncrank(3) +``` From 4300722835670320a59e0196d1e732b6dced9372 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 24 Sep 2025 19:26:05 -0400 Subject: [PATCH 098/126] small fixes for diagonal --- src/auxiliary/deprecate.jl | 1 + src/tensors/factorizations/diagonal.jl | 7 ++ src/tensors/factorizations/factorizations.jl | 2 +- test/diagonal.jl | 87 -------------------- test/factorizations.jl | 23 ++++-- 5 files changed, 25 insertions(+), 95 deletions(-) diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index 163c4933c..6417cb1ec 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -64,6 +64,7 @@ _kindof(::MatrixAlgebraKit.LAPACK_HouseholderQR) = :qr _kindof(::MatrixAlgebraKit.LAPACK_HouseholderLQ) = :lq _kindof(::MatrixAlgebraKit.LAPACK_SVDAlgorithm) = :svd _kindof(::MatrixAlgebraKit.PolarViaSVD) = :polar +_kindof(::DiagonalAlgorithm) = :svd # shouldn't really matter _drop_alg(; alg=nothing, kwargs...) = kwargs _drop_p(; p=nothing, kwargs...) = kwargs diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index 1f4445641..589feef1b 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -49,6 +49,13 @@ for f! in (:lq_full!, :lq_compact!) end end +function initialize_output(::typeof(left_orth!), d::AdjointTensorMap) + return d, similar(d) +end +function initialize_output(::typeof(right_orth!), d::AdjointTensorMap) + return similar(d), d +end + for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!, :right_orth!, :left_orth!) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index ac230ac43..a9f53e9bf 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -11,7 +11,7 @@ export qr_full, qr_compact, qr_null export qr_full!, qr_compact!, qr_null! export lq_full, lq_compact, lq_null export lq_full!, lq_compact!, lq_null! -export copy_oftype, permutedcopy_oftype, one! +export copy_oftype, permutedcopy_oftype, factorisation_scalartype, one! export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD using ..TensorKit diff --git a/test/diagonal.jl b/test/diagonal.jl index 68f6cedc6..8d20c6ca0 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -185,93 +185,6 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @planar E2[-1 -2 -3; -4 -5] = B[-1 -2 1; -4 -5] * t'[-3; 1] @test E1 ≈ E2 end - @timedtestset "Factorization" begin - for T in (Float32, ComplexF64) - t = DiagonalTensorMap(rand(T, reduceddim(V)), V) - @testset "eig/eigh" begin - D, W = @constinferred eig_full(t) - @test t * W ≈ W * D - D, W = @constinferred eig(t) - @test t * W ≈ W * D - t2 = t + t' - D2, V2 = @constinferred eigh(t2) - VdV2 = V2' * V2 - @test VdV2 ≈ one(VdV2) - @test t2 * V2 ≈ V2 * D2 - - D3 = @constinferred eigh_vals(t2) - @test D2 ≈ D3 - - @test rank(D) ≈ rank(t) - @test cond(D) ≈ cond(t) - @test all(((s, t),) -> isapprox(s, t), - zip(values(LinearAlgebra.eigvals(D)), - values(LinearAlgebra.eigvals(t)))) - D, W = @constinferred eig!(t) - @test D === t - @test W == one(t) - @test t * W ≈ W * D - D2, V2 = @constinferred eigh!(t2) - if T <: Real - @test D2 === t2 - end - @test V2 == one(t) - @test t2 * V2 ≈ V2 * D2 - end - @testset "leftorth with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) - Q, R = @constinferred leftorth(t; alg=alg) - QdQ = Q' * Q - @test QdQ ≈ one(QdQ) - @test Q * R ≈ t - if alg isa Polar - @test isposdef(R) - end - end - @testset "rightorth with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) - L, Q = @constinferred rightorth(t; alg=alg) - QQd = Q * Q' - @test QQd ≈ one(QQd) - @test L * Q ≈ t - if alg isa Polar - @test isposdef(L) - end - end - @testset "qr_full with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) - Q, R = @constinferred qr_full(t; alg=alg) - @test isisometry(Q) - @test Q * R ≈ t - end - @testset "qr_compact with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) - Q, R = @constinferred qr_compact(t; alg=alg) - @test isisometry(Q) - @test Q * R ≈ t - end - @testset "lq_full with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) - L, Q = @constinferred lq_full(t; alg=alg) - @test isisometry(Q; side=:right) - @test L * Q ≈ t - end - @testset "lq_compact with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) - L, Q = @constinferred lq_compact(t; alg=alg) - @test isisometry(Q; side=:right) - @test L * Q ≈ t - end - @testset "tsvd with $alg" for alg in (TensorKit.DiagonalAlgorithm(),) - U, S, Vᴴ = @constinferred tsvd(t; alg=alg) - UdU = U' * U - @test UdU ≈ one(UdU) - VdV = Vᴴ * Vᴴ' - @test VdV ≈ one(VdV) - @test U * S * Vᴴ ≈ t - - @test rank(S) ≈ rank(t) - @test cond(S) ≈ cond(t) - @test all(((s, t),) -> isapprox(s, t), - zip(values(LinearAlgebra.svdvals(S)), - values(LinearAlgebra.svdvals(t)))) - end - end - end @timedtestset "Tensor functions" begin for T in (Float64, ComplexF64) d = DiagonalTensorMap(rand(T, reduceddim(V)), V) diff --git a/test/factorizations.jl b/test/factorizations.jl index ad9015430..daf178194 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -28,7 +28,8 @@ for V in spacelist @testset "QR decomposition" begin for T in eltypes, - t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)') + t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) Q, R = @constinferred qr_full(t) @test Q * R ≈ t @@ -78,7 +79,8 @@ for V in spacelist @testset "LQ decomposition" begin for T in eltypes, - t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)') + t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) L, Q = @constinferred lq_full(t) @test L * Q ≈ t @@ -124,7 +126,8 @@ for V in spacelist @testset "Polar decomposition" begin for T in eltypes, - t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)') + t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) @assert domain(t) ≾ codomain(t) w, p = @constinferred left_polar(t) @@ -156,7 +159,8 @@ for V in spacelist for T in eltypes, t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W), - rand(T, W, V1)', rand(T, V1, W)') + rand(T, W, V1)', rand(T, V1, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) u, s, vᴴ = @constinferred svd_full(t) @test u * s * vᴴ ≈ t @@ -201,7 +205,11 @@ for V in spacelist end @testset "Eigenvalue decomposition" begin - for T in eltypes, t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)') + for T in eltypes, + t in + (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) + d, v = @constinferred eig_full(t) @test t * v ≈ v * d @@ -213,7 +221,7 @@ for V in spacelist vdv = v' * v vdv = (vdv + vdv') / 2 @test @constinferred isposdef(vdv) - @test !isposdef(t) # unlikely for non-hermitian map + t isa DiagonalTensorMap || @test !isposdef(t) # unlikely for non-hermitian map d, v = @constinferred eig_trunc(t; trunc=truncrank(dim(domain(t)) ÷ 2)) @test t * v ≈ v * d @@ -254,7 +262,8 @@ for V in spacelist for T in eltypes, t in (rand(T, W, W), rand(T, W, W)', rand(T, W, V1), rand(T, V1, W), - rand(T, W, V1)', rand(T, V1, W)') + rand(T, W, V1)', rand(T, V1, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) d1, d2 = dim(codomain(t)), dim(domain(t)) @test rank(t) == min(d1, d2) From 71a6ba2cf7b3b94dcfbeb0c2148544a05e8ca931 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 17:34:29 -0400 Subject: [PATCH 099/126] remove more unused files --- src/auxiliary/linalg.jl | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 src/auxiliary/linalg.jl diff --git a/src/auxiliary/linalg.jl b/src/auxiliary/linalg.jl deleted file mode 100644 index c39a6b435..000000000 --- a/src/auxiliary/linalg.jl +++ /dev/null @@ -1,4 +0,0 @@ -# Simple reference to getting and setting BLAS threads -#------------------------------------------------------ -set_num_blas_threads(n::Integer) = LinearAlgebra.BLAS.set_num_threads(n) -get_num_blas_threads() = LinearAlgebra.BLAS.get_num_threads() From 783256a0ea669879325ff38139c39e9e38abf260 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 17:48:02 -0400 Subject: [PATCH 100/126] fix some diagonal edge cases --- src/tensors/factorizations/diagonal.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index 589feef1b..b971d94d2 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -49,13 +49,22 @@ for f! in (:lq_full!, :lq_compact!) end end -function initialize_output(::typeof(left_orth!), d::AdjointTensorMap) +function initialize_output(::typeof(left_orth!), d::DiagonalTensorMap) return d, similar(d) end -function initialize_output(::typeof(right_orth!), d::AdjointTensorMap) +function initialize_output(::typeof(right_orth!), d::DiagonalTensorMap) return similar(d), d end +function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::DiagonalAlgorithm) + V_cod = fuse(codomain(t)) + V_dom = fuse(domain(t)) + U = similar(t, codomain(t) ← V_cod) + S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod ← V_dom) + Vᴴ = similar(t, V_dom ← domain(t)) + return U, S, Vᴴ +end + for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!, :right_orth!, :left_orth!) From f31fae4b6057c753dcfdd448ae6f60cdd051caee Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 18:56:00 -0400 Subject: [PATCH 101/126] rework orths to not take allocate output first --- src/tensors/factorizations/factorizations.jl | 2 +- .../factorizations/matrixalgebrakit.jl | 72 ++++++++++++++++--- 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index a9f53e9bf..14d5832dc 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -36,7 +36,7 @@ import MatrixAlgebraKit: default_algorithm, eigh_full!, eigh_trunc!, eigh_vals!, eig_full!, eig_trunc!, eig_vals!, left_polar!, left_orth_polar!, right_polar!, right_orth_polar!, - left_null_svd!, right_null_svd!, + left_null_svd!, right_null_svd!, left_orth_svd!, right_orth_svd!, left_orth!, right_orth!, left_null!, right_null!, truncate!, findtruncated, findtruncated_sorted, diagview, isisometry diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 8bfc268e0..1073a6430 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -553,16 +553,6 @@ function initialize_output(::typeof(right_polar!), t::AbstractTensorMap, return P, Wᴴ end -# Needed to get algorithm selection to behave -function left_orth_polar!(t::AbstractTensorMap, VC, alg) - alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg) - return left_orth_polar!(t, VC, alg′) -end -function right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg) - alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg) - return right_orth_polar!(t, CVᴴ, alg′) -end - # Orthogonalization # ----------------- function check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, ::AbstractAlgorithm) @@ -609,6 +599,68 @@ function initialize_output(::typeof(right_orth!), t::AbstractTensorMap) return C, Vᴴ end +# This is a rework of the dispatch logic in order to avoid having to deal with having to +# allocate the output before knowing the kind of decomposition. In particular, here I disable +# providing output arguments for left_ and right_orth. +# This is mainly because polar decompositions have different shapes, and SVD for Diagonal +# also does +function left_orth!(t::AbstractTensorMap; + trunc::TruncationStrategy=notrunc(), + kind=trunc == notrunc() ? :qr : :svd, + alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;)) + trunc == notrunc() || kind === :svd || + throw(ArgumentError("truncation not supported for left_orth with kind = $kind")) + + kind === :qr && return qr_compact!(t; alg_qr...) + kind === :polar && return left_orth_polar!(t; alg_polar...) + kind === :svd && return left_orth_svd!(t; trunc, alg_svd...) + + throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`")) +end +function right_orth!(t::AbstractTensorMap; + trunc::TruncationStrategy=notrunc(), + kind=trunc == notrunc() ? :lq : :svd, + alg_lq=(; positive=true), alg_polar=(;), alg_svd=(;)) + trunc == notrunc() || kind === :svd || + throw(ArgumentError("truncation not supported for right_orth with kind = $kind")) + + kind === :lq && return lq_compact!(t; alg_lq...) + kind === :polar && return right_orth_polar!(t; alg_polar...) + kind === :svd && return right_orth_svd!(t; trunc, alg_svd...) + + throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`")) +end + +function left_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...) + alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg; kwargs...) + VC = initialize_output(left_orth!, t) + return left_orth_polar!(t, VC, alg′) +end +function left_orth_polar!(t::AbstractTensorMap, VC, alg) + alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg) + return left_orth_polar!(t, VC, alg′) +end +function right_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...) + alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg; kwargs...) + CVᴴ = initialize_output(right_orth!, t) + return right_orth_polar!(t, CVᴴ, alg′) +end +function right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg) + alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg) + return right_orth_polar!(t, CVᴴ, alg′) +end + +function left_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) + U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) : + svd_trunc!(t; trunc, kwargs...) + return U, lmul!(S, Vᴴ) +end +function right_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) + U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) : + svd_trunc!(t; trunc, kwargs...) + return rmul!(U, S), Vᴴ +end + # Nullspace # --------- function check_input(::typeof(left_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) From 27e947093de2840a5f68193f19ccc201f5467f34 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 19:02:09 -0400 Subject: [PATCH 102/126] more cleanup --- src/tensors/factorizations/factorizations.jl | 6 ------ test/factorizations.jl | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 14d5832dc..f2458a31d 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -54,12 +54,6 @@ function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) return isisometry(t) end -# Orthogonal factorizations (mutation for recycling memory): -# only possible if scalar type is floating point -# only correct if Euclidean inner product -#------------------------------------------------------------------------------------------ -const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} - #------------------------------# # LinearAlgebra overloads #------------------------------# diff --git a/test/factorizations.jl b/test/factorizations.jl index daf178194..363cf92d4 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -230,7 +230,7 @@ for V in spacelist t2 = (t + t') D, V = eigen(t2) @test isisometry(V) - D̃, Ṽ = @constinferred eigh(t2) + D̃, Ṽ = @constinferred eigh_full(t2) @test D ≈ D̃ @test V ≈ Ṽ λ = minimum(minimum(real(LinearAlgebra.diag(b))) From 3a274b76204515723fd100e5223477a60e265e7c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 19:30:41 -0400 Subject: [PATCH 103/126] more fixes for docs --- docs/Project.toml | 3 +- docs/make.jl | 11 ++++++-- docs/src/lib/tensors.md | 36 +++--------------------- src/TensorKit.jl | 1 + src/tensors/factorizations/truncation.jl | 14 +++++++++ 5 files changed, 29 insertions(+), 36 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index bb471e442..fc67550fd 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,10 +1,9 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" -TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f" [compat] Documenter = "1" Random = "1" -TensorKitSectors = "0.1" diff --git a/docs/make.jl b/docs/make.jl index 2dab6cbdd..1ad55321d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,12 @@ using Documenter using Random -using TensorKit, TensorKitSectors +using TensorKit +using TensorKit.TensorKitSectors +using TensorKit.MatrixAlgebraKit +using DocumenterInterLinks + +links = InterLinks("MatrixAlgebraKit" => "https://quantumkithub.github.io/MatrixAlgebraKit.jl/stable/", + "TensorOperations" => "https://quantumkithub.github.io/TensorOperations.jl/stable/") pages = ["Home" => "index.md", "Manual" => ["man/intro.md", "man/tutorial.md", "man/categories.md", @@ -15,6 +21,7 @@ makedocs(; modules=[TensorKit, TensorKitSectors], format=Documenter.HTML(; prettyurls=true, mathengine=MathJax(), assets=["assets/custom.css"]), pages=pages, - pagesonly=true) + pagesonly=true, + plugins=[links]) deploydocs(; repo="github.com/QuantumKitHub/TensorKit.jl.git", push_preview=true) diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index 06f5da3d1..167a83adc 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -219,44 +219,16 @@ and all follow the same strategy. The idea is that the `TensorMap` is interprete map based on the current partition of indices between `domain` and `codomain`, and then the entire range of MatrixAlgebraKit functions can be called. You can specify an additional permutation of the domain and codomain indices before the -factorisation is performed by making use of [`permute`](@ref) or [`transpose`](@ref), +factorisation is performed by making use of [`permute`](@ref) or [`transpose`](@ref). -```@docs -left_orth -right_orth -left_null -right_null -svd_compact -svd_full -svd_vals -eig_full -eig_vals -eigh_full -eigh_vals -isposdef -``` +For the full list of factorizations, see [Decompositions](@extref MatrixAlgebraKit). Additionally, it is possible to obtain truncated versions of some of these factorizations through the [`MatrixAlgebraKit.TruncationStrategy`](@ref) objects. -```@docs -svd_trunc -eig_trunc -eigh_trunc -``` - -The exact truncation strategy can be controlled through one of the following strategies: +The exact truncation strategy can be controlled through the strategies defined in [Truncations](@extref MatrixAlgebraKit), +but for `TensorMap`s there is also the special-purpose scheme: ```@docs -notrunc -trunctol -truncrank -truncerr truncspace ``` - -It is additionally possible to combine multiple strategies through `&`, e.g. - -```julia -combined_truncation = trunctol(; atol=1e-2) & truncrank(3) -``` diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 88467fa4b..8934660d5 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -81,6 +81,7 @@ export left_orth, right_orth, left_null, right_null, eig_trunc, eigh_vals!, eigh_vals, eig_vals!, eig_vals, isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond + # deprecate: export eig, eig!, eigh, eigh!, eigen, eigen!, tsvd, tsvd!, leftorth, leftorth!, rightorth, rightorth!, leftnull, leftnull!, rightnull, rightnull! diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 708e5c252..e19da07f7 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -1,5 +1,8 @@ # Strategies # ---------- +""" + notrunc() +""" notrunc() = NoTruncation() # deprecate @@ -10,11 +13,22 @@ struct TruncationError{T<:Real} <: TruncationStrategy ϵ::T p::Real end + +""" + truncerr(epsilon, p) +""" truncerr(epsilon::Real, p::Real=2) = TruncationError(epsilon, p) struct TruncationSpace{S<:ElementarySpace} <: TruncationStrategy space::S end + +""" + truncspace(space::ElementarySpace) + +Truncation strategy to keep the first values such that the resulting space is the infimum of +the total space and the provided space. +""" truncspace(space::ElementarySpace) = TruncationSpace(space) # Truncation From 19cce9dbbc3e69cd9297ee815555f0f438c711f8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 21:24:58 -0400 Subject: [PATCH 104/126] Fix and refactor truncation tests --- src/tensors/factorizations/truncation.jl | 14 ++++--- test/factorizations.jl | 48 ++++++++++++++++++++++++ test/tensors.jl | 44 ---------------------- 3 files changed, 57 insertions(+), 49 deletions(-) diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index e19da07f7..5b91634ab 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -36,6 +36,7 @@ truncspace(space::ElementarySpace) = TruncationSpace(space) function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ)::Tuple{AbstractTensorMap,AbstractTensorMap,AbstractTensorMap}, strategy::TruncationStrategy) + strategy == notrunc() && return (U, S, Vᴴ) ind = findtruncated_sorted(diagview(S), strategy) V_truncated = spacetype(S)(c => length(I) for (c, I) in ind) @@ -66,6 +67,7 @@ end function truncate!(::typeof(left_null!), (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, strategy::MatrixAlgebraKit.TruncationStrategy) + strategy == notrunc() && return (U, S) extended_S = SectorDict(c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) for (c, b) in blocks(S)) @@ -82,6 +84,7 @@ for f! in (:eig_trunc!, :eigh_trunc!) @eval function truncate!(::typeof($f!), (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, strategy::TruncationStrategy) + strategy == notrunc() && return (D, V) ind = findtruncated(diagview(D), strategy) V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) @@ -136,10 +139,14 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}; by=identity, end # implementations +function findtruncated_sorted(S::SectorDict, strategy::TruncationStrategy) + return findtruncated(S, strategy) +end + function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove) atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol)) - return SectorDict(c => findtrunc(d) for (c, d) in Sd) + return SectorDict(c => findtrunc(d) for (c, d) in S) end function findtruncated(S::SectorDict, strategy::TruncationKeepAbove) atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) @@ -195,13 +202,10 @@ function findtruncated(Sd::SectorDict, strategy::TruncationKeepSorted) _, cmin = next truncdim[cmin] -= 1 totaldim -= dim(cmin) - if totaldim < strategy.howmany - # truncdim[cmin] += 1 - break - end if truncdim[cmin] == 0 delete!(truncdim, cmin) end + totaldim <= strategy.howmany && break end return SectorDict(c => permutations[c][Base.OneTo(d)] for (c, d) in truncdim) end diff --git a/test/factorizations.jl b/test/factorizations.jl index 363cf92d4..71fbbbeb7 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -204,6 +204,54 @@ for V in spacelist end end + @testset "truncated SVD" begin + for T in eltypes, + t in (randn(T, W, W), randn(T, W, W)', + randn(T, W, V1), randn(T, V1, W), + randn(T, W, V1)', randn(T, V1, W)', + DiagonalTensorMap(randn(T, reduceddim(V1)), V1)) + + @constinferred normalize!(t) + + U, S, Vᴴ = @constinferred svd_trunc(t; trunc=notrunc()) + @test U * S * Vᴴ ≈ t + @test isisometry(U) + @test isisometry(Vᴴ; side=:right) + + trunc = truncrank(dim(domain(S)) ÷ 2) + U1, S1, Vᴴ1 = @constinferred svd_trunc(t; trunc) + @test t * Vᴴ1' ≈ U1 * S1 + @test isisometry(U1) + @test isisometry(Vᴴ1; side=:right) + @test dim(domain(S1)) <= trunc.howmany + + λ = minimum(minimum, values(LinearAlgebra.diag(S1))) + trunc = trunctol(λ - 10eps(λ)) + U2, S2, Vᴴ2 = @constinferred svd_trunc(t; trunc) + @test t * Vᴴ2' ≈ U2 * S2 + @test isisometry(U2) + @test isisometry(Vᴴ2; side=:right) + @test minimum(minimum, values(LinearAlgebra.diag(S1))) >= λ + @test U2 ≈ U1 + @test S2 ≈ S1 + @test Vᴴ2 ≈ Vᴴ1 + + trunc = truncspace(space(S2, 1)) + U3, S3, Vᴴ3 = @constinferred svd_trunc(t; trunc) + @test t * Vᴴ3' ≈ U3 * S3 + @test isisometry(U3) + @test isisometry(Vᴴ3; side=:right) + @test space(S3, 1) ≾ space(S2, 1) + + trunc = truncerr(0.5) + U4, S4, Vᴴ4 = @constinferred svd_trunc(t; trunc) + @test t * Vᴴ4' ≈ U4 * S4 + @test isisometry(U4) + @test isisometry(Vᴴ4; side=:right) + @test norm(t - U4 * S4 * Vᴴ4) <= 0.5 + end + end + @testset "Eigenvalue decomposition" begin for T in eltypes, t in diff --git a/test/tensors.jl b/test/tensors.jl index 796becded..f60de0d1b 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -432,50 +432,6 @@ for V in spacelist @test LinearAlgebra.isdiag(D) @test LinearAlgebra.diag(D) == d end - @timedtestset "Tensor truncation" begin - for T in (Float32, ComplexF64) - for p in (1, 2, 3, Inf) - # Test both a normal tensor and an adjoint one. - ts = (randn(T, V1 ⊗ V2 ⊗ V3, V4 ⊗ V5), - randn(T, V4 ⊗ V5, V1 ⊗ V2 ⊗ V3)') - for t in ts - U₀, S₀, V₀, = tsvd(t) - t = rmul!(t, 1 / norm(S₀, p)) - U, S, V = @constinferred tsvd(t; trunc=truncerr(5e-1, p)) - ϵ = TensorKit._norm(LinearAlgebra.svdvals(U * S * V - t), p, - zero(scalartype(S))) - p == 2 && @test ϵ < 5e-1 - # @show p, ϵ - # @show domain(S) - # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′ = tsvd(t; trunc=truncerr(ϵ + 10eps(ϵ), p)) - ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, - zero(scalartype(S))) - - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S))))) - ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, - zero(scalartype(S))) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - U′, S′, V′ = tsvd(t; trunc=truncspace(space(S, 1))) - ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, - zero(scalartype(S))) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - # results with truncationcutoff cannot be compared because they don't take degeneracy into account, and thus truncate differently - U, S, V = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀)))) - ϵ = TensorKit._norm(LinearAlgebra.svdvals(U * S * V - t), p, - zero(scalartype(S))) - # @show p, ϵ - # @show domain(S) - # @test min(space(S,1), space(S₀,1)) != space(S₀,1) - U′, S′, V′ = tsvd(t; trunc=truncspace(space(S, 1))) - ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p, - zero(scalartype(S))) - @test (U, S, V, ϵ) == (U′, S′, V′, ϵ′) - end - end - end - end if BraidingStyle(I) isa Bosonic && hasfusiontensor(I) @timedtestset "Tensor functions" begin W = V1 ⊗ V2 From e5ee802a48c806914f979a32c39f14266cd227c4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 25 Sep 2025 21:59:47 -0400 Subject: [PATCH 105/126] fix AD tests --- .../factorizations.jl | 4 +-- src/TensorKit.jl | 3 -- .../factorizations/matrixalgebrakit.jl | 32 +++++++++++++------ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index 0c6924411..35b78a52a 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -65,7 +65,7 @@ function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTenso end function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) - alg isa TensorKit.QR || alg isa TensorKit.QRpos || + alg isa MatrixAlgebraKit.LAPACK_HouseholderQR || error("only `alg=QR()` and `alg=QRpos()` are supported") QR = leftorth(t; alg) function leftorth!_pullback(ΔQR′) @@ -85,7 +85,7 @@ function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRp end function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) - alg isa TensorKit.LQ || alg isa TensorKit.LQpos || + alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ || error("only `alg=LQ()` and `alg=LQpos()` are supported") LQ = rightorth(t; alg) function rightorth!_pullback(ΔLQ′) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 8934660d5..21191f4bb 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -82,9 +82,6 @@ export left_orth, right_orth, left_null, right_null, eigh_vals!, eigh_vals, eig_vals!, eig_vals, isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond -# deprecate: -export eig, eig!, eigh, eigh!, eigen, eigen!, tsvd, tsvd!, leftorth, leftorth!, rightorth, - rightorth!, leftnull, leftnull!, rightnull, rightnull! export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! export catdomain, catcodomain, absorb, absorb! diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 1073a6430..560a210d5 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -611,11 +611,17 @@ function left_orth!(t::AbstractTensorMap; trunc == notrunc() || kind === :svd || throw(ArgumentError("truncation not supported for left_orth with kind = $kind")) - kind === :qr && return qr_compact!(t; alg_qr...) - kind === :polar && return left_orth_polar!(t; alg_polar...) - kind === :svd && return left_orth_svd!(t; trunc, alg_svd...) - - throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`")) + return if kind === :qr + alg_qr isa NamedTuple ? qr_compact!(t; alg_qr...) : qr_compact!(t; alg=alg_qr) + elseif kind === :polar + alg_polar isa NamedTuple ? left_orth_polar!(t; alg_polar...) : + left_orth_polar!(t; alg=alg_polar) + elseif kind === :svd + alg_svd isa NamedTuple ? left_orth_svd!(t; trunc, alg_svd...) : + left_orth_svd!(t; trunc, alg=alg_svd) + else + throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`")) + end end function right_orth!(t::AbstractTensorMap; trunc::TruncationStrategy=notrunc(), @@ -624,11 +630,17 @@ function right_orth!(t::AbstractTensorMap; trunc == notrunc() || kind === :svd || throw(ArgumentError("truncation not supported for right_orth with kind = $kind")) - kind === :lq && return lq_compact!(t; alg_lq...) - kind === :polar && return right_orth_polar!(t; alg_polar...) - kind === :svd && return right_orth_svd!(t; trunc, alg_svd...) - - throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`")) + return if kind === :lq + alg_lq isa NamedTuple ? lq_compact!(t; alg_lq...) : lq_compact!(t; alg=alg_lq) + elseif kind === :polar + alg_polar isa NamedTuple ? right_orth_polar!(t; alg_polar...) : + right_orth_polar!(t; alg=alg_polar) + elseif kind === :svd + alg_svd isa NamedTuple ? right_orth_svd!(t; trunc, alg_svd...) : + right_orth_svd!(t; trunc, alg=alg_svd) + else + throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`")) + end end function left_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...) From 7c5419babe9e424157bae71d139a6f534c4146cd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 26 Sep 2025 09:46:41 -0400 Subject: [PATCH 106/126] rework truncation to be MatrixAlgebraKit 0.4 compliant --- Project.toml | 2 +- src/TensorKit.jl | 2 +- src/tensors/factorizations/factorizations.jl | 18 +- src/tensors/factorizations/truncation.jl | 188 +++++++++---------- test/factorizations.jl | 4 +- 5 files changed, 104 insertions(+), 110 deletions(-) diff --git a/Project.toml b/Project.toml index 1986e2963..e9790c851 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.3.2" +MatrixAlgebraKit = "0.4.0" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 21191f4bb..668a32976 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -91,7 +91,7 @@ export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor export scalar, add!, contract! # truncation schemes -export notrunc, truncerr, truncrank, truncspace, trunctol +export notrunc, truncrank, trunctol, truncfilter, truncspace, truncerror # cache management export empty_globalcaches! diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index f2458a31d..6f447612a 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -12,7 +12,8 @@ export qr_full!, qr_compact!, qr_null! export lq_full, lq_compact, lq_null export lq_full!, lq_compact!, lq_null! export copy_oftype, permutedcopy_oftype, factorisation_scalartype, one! -export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD +export TruncationScheme, notrunc, trunctol, truncerror, truncrank, truncspace, truncfilter, + PolarViaSVD using ..TensorKit using ..TensorKit: AdjointTensorMap, SectorDict, blocktype, foreachblock, one! @@ -23,12 +24,13 @@ import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian using TensorOperations: Index2Tuple using MatrixAlgebraKit -using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy, - NoTruncation, TruncationKeepAbove, TruncationKeepBelow, - TruncationIntersection, TruncationKeepFiltered, PolarViaSVD, - LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR, - LAPACK_HouseholderLQ, LAPACK_HouseholderQL, LAPACK_HouseholderRQ, - DiagonalAlgorithm +using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm +using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, + TruncationByError, TruncationIntersection, TruncationByFilter, + TruncationByOrder +using MatrixAlgebraKit: PolarViaSVD +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR, + LAPACK_HouseholderLQ, LAPACK_HouseholderQL, LAPACK_HouseholderRQ import MatrixAlgebraKit: default_algorithm, copy_input, check_input, initialize_output, qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, @@ -38,7 +40,7 @@ import MatrixAlgebraKit: default_algorithm, left_polar!, left_orth_polar!, right_polar!, right_orth_polar!, left_null_svd!, right_null_svd!, left_orth_svd!, right_orth_svd!, left_orth!, right_orth!, left_null!, right_null!, - truncate!, findtruncated, findtruncated_sorted, + truncate!, findtruncated, findtruncated_svd, diagview, isisometry include("utility.jl") diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 5b91634ab..5d5178470 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -1,43 +1,40 @@ # Strategies # ---------- -""" - notrunc() -""" -notrunc() = NoTruncation() -# deprecate +# TODO: deprecate const TruncationScheme = TruncationStrategy -# TODO: add this to MatrixAlgebraKit -struct TruncationError{T<:Real} <: TruncationStrategy - ϵ::T - p::Real -end - -""" - truncerr(epsilon, p) """ -truncerr(epsilon::Real, p::Real=2) = TruncationError(epsilon, p) + TruncationSpace(V::ElementarySpace, by::Function, rev::Bool) -struct TruncationSpace{S<:ElementarySpace} <: TruncationStrategy +Truncation strategy to keep the first values for each sector when sorted according to `by` and `rev`, +such that the resulting vector space is no greater than `V`. + +See also [`truncspace`](@ref). +""" +struct TruncationSpace{S<:ElementarySpace,F} <: TruncationStrategy space::S + by::F + rev::Bool end """ - truncspace(space::ElementarySpace) + truncspace(space::ElementarySpace; by=abs, rev::Bool=true) -Truncation strategy to keep the first values such that the resulting space is the infimum of -the total space and the provided space. +Truncation strategy to keep the first values for each sector when sorted according to `by` and `rev`, +such that the resulting vector space is no greater than `V`. """ -truncspace(space::ElementarySpace) = TruncationSpace(space) +function truncspace(space::ElementarySpace; by=abs, rev::Bool=true) + isdual(space) && throw(ArgumentError("resulting vector space is never dual")) + return TruncationSpace(space, by, rev) +end -# Truncation -# ---------- +# truncate! +# --------- function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ)::Tuple{AbstractTensorMap,AbstractTensorMap,AbstractTensorMap}, strategy::TruncationStrategy) - strategy == notrunc() && return (U, S, Vᴴ) - ind = findtruncated_sorted(diagview(S), strategy) + ind = findtruncated_svd(diagview(S), strategy) V_truncated = spacetype(S)(c => length(I) for (c, I) in ind) Ũ = similar(U, codomain(U) ← V_truncated) @@ -67,7 +64,6 @@ end function truncate!(::typeof(left_null!), (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, strategy::MatrixAlgebraKit.TruncationStrategy) - strategy == notrunc() && return (U, S) extended_S = SectorDict(c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) for (c, b) in blocks(S)) @@ -84,7 +80,6 @@ for f! in (:eig_trunc!, :eigh_trunc!) @eval function truncate!(::typeof($f!), (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, strategy::TruncationStrategy) - strategy == notrunc() && return (D, V) ind = findtruncated(diagview(D), strategy) V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) @@ -109,7 +104,7 @@ end # Find truncation # --------------- # auxiliary functions -rtol_to_atol(S, p, atol, rtol) = rtol > 0 ? max(atol, _norm(S, p) * rtol) : atol +rtol_to_atol(S, p, atol, rtol) = rtol > 0 ? max(atol, TensorKit._norm(S, p) * rtol) : atol function _compute_truncerr(Σdata, truncdim, p=2) I = keytype(Σdata) @@ -138,99 +133,96 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}; by=identity, end end -# implementations -function findtruncated_sorted(S::SectorDict, strategy::TruncationStrategy) - return findtruncated(S, strategy) +# findtruncated +# ------------- +# Generic fallback +function findtruncated_svd(values::SectorDict, strategy::TruncationStrategy) + return findtruncated(values, strategy) end -function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove) - atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) - findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol)) - return SectorDict(c => findtrunc(d) for (c, d) in S) -end -function findtruncated(S::SectorDict, strategy::TruncationKeepAbove) - atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) - findtrunc = Base.Fix2(findtruncated, truncbelow(atol)) - return SectorDict(c => findtrunc(d) for (c, d) in Sd) -end - -function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepBelow) - atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) - findtrunc = Base.Fix2(findtruncated_sorted, truncabove(atol)) - return SectorDict(c => findtrunc(d) for (c, d) in Sd) -end -function findtruncated(S::SectorDict, strategy::TruncationKeepBelow) - atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol) - findtrunc = Base.Fix2(findtruncated, truncabove(atol)) - return SectorDict(c => findtrunc(d) for (c, d) in Sd) -end - -function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError) - I = keytype(Sd) - truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd) - while true - next = _findnexttruncvalue(Sd, truncdim) - isnothing(next) && break - σmin, cmin = next - truncdim[cmin] -= 1 - err = _compute_truncerr(Sd, truncdim, strategy.p) - if err > strategy.ϵ - truncdim[cmin] += 1 - break - end - if truncdim[cmin] == 0 - delete!(truncdim, cmin) - end - end - return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) +function findtruncated(values::SectorDict, ::NoTruncation) + return SectorDict(c => Base.OneTo(length(b)) for (c, b) in values) end -function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted) - return findtruncated(Sd, strategy) +function findtruncated(values::SectorDict, strategy::TruncationByOrder) + perms = SectorDict(c => (sortperm(d; strategy.by, strategy.rev)) for (c, d) in values) + values_sorted = SectorDict(c => d[perms[c]] for (c, d) in values) + inds = findtruncated_svd(values_sorted, truncrank(strategy.howmany)) + return SectorDict(c => perms[c][I] for (c, I) in inds) end -function findtruncated(Sd::SectorDict, strategy::TruncationKeepSorted) - permutations = SectorDict(c => (sortperm(d; strategy.by, strategy.rev)) - for (c, d) in Sd) - Sd = SectorDict(c => sort(d; strategy.by, strategy.rev) for (c, d) in Sd) - - I = keytype(Sd) - truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd) +function findtruncated_svd(values::SectorDict, strategy::TruncationByOrder) + I = keytype(values) + truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in values) totaldim = sum(dim(c) * d for (c, d) in truncdim; init=0) while true - next = _findnexttruncvalue(Sd, truncdim; strategy.by, strategy.rev) + next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev) isnothing(next) && break _, cmin = next truncdim[cmin] -= 1 totaldim -= dim(cmin) - if truncdim[cmin] == 0 - delete!(truncdim, cmin) - end + truncdim[cmin] == 0 && delete!(truncdim, cmin) totaldim <= strategy.howmany && break end - return SectorDict(c => permutations[c][Base.OneTo(d)] for (c, d) in truncdim) + return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim) end -function findtruncated_sorted(Sd::SectorDict, strategy::TruncationSpace) - I = keytype(Sd) - return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(min(length(d), - dim(strategy.space, c))) - for (c, d) in Sd) +function findtruncated(values::SectorDict, strategy::TruncationByFilter) + return SectorDict(c => findall(strategy.filter, d) for (c, d) in values) +end + +function findtruncated(values::SectorDict, strategy::TruncationByValue) + atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) + strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) + return SectorDict(c => findtruncated(d, strategy′) for (c, d) in values) +end +function findtruncated_svd(values::SectorDict, strategy::TruncationByValue) + atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) + strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) + return SectorDict(c => findtruncated_svd(d, strategy′) for (c, d) in values) +end + +function findtruncated(values::SectorDict, strategy::TruncationByError) + perms = SectorDict(c => sortperm(d; by=abs, rev=true) for (c, d) in values) + values_sorted = SectorDict(c => d[perms[c]] for (c, d) in Sd) + inds = findtruncated_svd(values_sorted, truncrank(strategy.howmany)) + return SectorDict(c => perms[c][I] for (c, I) in inds) +end +function findtruncated_svd(values::SectorDict, strategy::TruncationByError) + I = keytype(values) + truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in values) + by(c, v) = abs(v)^strategy.p * dim(c) + Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), values) + ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ) + truncerrᵖ = zero(real(scalartype(valtype(values)))) + next = _findnexttruncvalue(values, truncdim) + while !isnothing(next) + σmin, cmin = next + truncerrᵖ += by(cmin, σmin) + truncerrᵖ >= ϵᵖ && break + (truncdim[cmin] -= 1) == 0 && delete!(truncdim, cmin) + next = _findnexttruncvalue(values, truncdim) + end + return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) end -function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepFiltered) - return SectorDict(c => findtruncated_sorted(d, strategy) for (c, d) in Sd) +function findtruncated(values::SectorDict, strategy::TruncationSpace) + blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) + return SectorDict(c => findtruncated(d, blockstrategy(c)) for (c, d) in values) end -function findtruncated(Sd::SectorDict, strategy::TruncationKeepFiltered) - return SectorDict(c => findtruncated(d, strategy) for (c, d) in Sd) +function findtruncated_svd(values::SectorDict, strategy::TruncationSpace) + blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) + return SectorDict(c => findtruncated_svd(d, blockstrategy(c)) for (c, d) in values) end -function findtruncated_sorted(Sd::SectorDict, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated_sorted, Sd), strategy) - return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...) +function findtruncated(values::SectorDict, strategy::TruncationIntersection) + inds = map(Base.Fix1(findtruncated, values), strategy) + return SectorDict(c => mapreduce(Base.Fix2(getindex, c), _ind_intersect, inds; + init=trues(length(values[c]))) for c in intersect(map(keys, inds)...)) end -function findtruncated(Sd::SectorDict, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated, Sd), strategy) - return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...) +function findtruncated_svd(Sd::SectorDict, strategy::TruncationIntersection) + inds = map(Base.Fix1(findtruncated_svd, Sd), strategy) + return SectorDict(c => mapreduce(Base.Fix2(getindex, c), _ind_intersect, inds; + init=trues(length(values[c]))) for c in intersect(map(keys, inds)...)) end diff --git a/test/factorizations.jl b/test/factorizations.jl index 71fbbbeb7..33ebf0f56 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -226,7 +226,7 @@ for V in spacelist @test dim(domain(S1)) <= trunc.howmany λ = minimum(minimum, values(LinearAlgebra.diag(S1))) - trunc = trunctol(λ - 10eps(λ)) + trunc = trunctol(; atol=λ - 10eps(λ)) U2, S2, Vᴴ2 = @constinferred svd_trunc(t; trunc) @test t * Vᴴ2' ≈ U2 * S2 @test isisometry(U2) @@ -243,7 +243,7 @@ for V in spacelist @test isisometry(Vᴴ3; side=:right) @test space(S3, 1) ≾ space(S2, 1) - trunc = truncerr(0.5) + trunc = truncerror(; atol=0.5) U4, S4, Vᴴ4 = @constinferred svd_trunc(t; trunc) @test t * Vᴴ4' ≈ U4 * S4 @test isisometry(U4) From d256ffb5034ffca46a095c4e60dcd7da6b1c9272 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 26 Sep 2025 10:28:48 -0400 Subject: [PATCH 107/126] rework AD rules --- .../TensorKitChainRulesCoreExt.jl | 5 +- .../factorizations.jl | 332 +++++-- src/tensors/factorizations/adjoint.jl | 8 +- src/tensors/factorizations/factorizations.jl | 4 + src/tensors/factorizations/pullbacks.jl | 13 + test/ad.jl | 849 +++++++++++------- test/factorizations.jl | 8 +- 7 files changed, 789 insertions(+), 430 deletions(-) create mode 100644 src/tensors/factorizations/pullbacks.jl diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl index 36bb4108d..272b47c6d 100644 --- a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -13,9 +13,10 @@ using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract using VectorInterface: promote_scale, promote_add using MatrixAlgebraKit -using MatrixAlgebraKit: TruncationStrategy, +using MatrixAlgebraKit: TruncationStrategy, TruncatedAlgorithm, svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!, - qr_compact_pullback!, lq_compact_pullback! + qr_compact_pullback!, lq_compact_pullback!, left_polar_pullback!, + right_polar_pullback! include("utility.jl") include("constructors.jl") diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index 35b78a52a..aac5f0af6 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,104 +1,280 @@ # Factorizations rules # -------------------- -for f in (:tsvd, :eig, :eigh) - f! = Symbol(f, :!) - f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!) - f_pullback = Symbol(f, :_pullback) - f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!) - @eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap; - trunc::TruncationStrategy=TensorKit.notrunc(), - kwargs...) - # TODO: I think we can use f! here without issues because we don't actually require - # the data of `t` anymore. - F = $f(t; trunc=TensorKit.notrunc(), kwargs...) - - if trunc != TensorKit.notrunc() && !isempty(blocksectors(t)) - F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc) - else - F′ = F - end +function ChainRulesCore.rrule(::typeof(MatrixAlgebraKit.copy_input), f, + t::AbstractTensorMap) + project = ProjectTo(t) + copy_input_pullback(Δt) = (NoTangent(), NoTangent(), project(unthunk(Δt))) + return MatrixAlgebraKit.copy_input(f, t), copy_input_pullback +end + +@non_differentiable MatrixAlgebraKit.initialize_output(f, t::AbstractTensorMap, args...) +@non_differentiable MatrixAlgebraKit.check_input(f, t::AbstractTensorMap, args...) - function $f_pullback(ΔF′) - ΔF = unthunk.(ΔF′) +for qr_f in (:qr_compact, :qr_full) + qr_f! = Symbol(qr_f, '!') + @eval function ChainRulesCore.rrule(::typeof($qr_f!), t::AbstractTensorMap, QR, alg) + tc = MatrixAlgebraKit.copy_input($qr_f, t) + QR = $(qr_f!)(tc, QR, alg) + function qr_pullback(ΔQR′) + ΔQR = unthunk.(ΔQR′) Δt = zerovector(t) - foreachblock(Δt) do c, (b,) - Fc = block.(F, Ref(c)) - ΔFc = block.(ΔF, Ref(c)) - $f_pullback!(b, Fc, ΔFc) - return nothing - end - return NoTangent(), Δt + MatrixAlgebraKit.qr_compact_pullback!(Δt, QR, ΔQR) + return NoTangent(), Δt, ZeroTangent(), NoTangent() + end + function qr_pullback(::Tuple{ZeroTangent,ZeroTangent}) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end - $f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent() + return QR, qr_pullback + end +end +function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg) + Q, R = qr_full(t, alg) + for (c, b) in blocks(t) + m, n = size(b) + copy!(block(N, c), view(block(Q, c), 1:m, (n + 1):m)) + end - return F′, $f_pullback + function qr_null_pullback(ΔN′) + ΔN = unthunk(ΔN′) + Δt = zerovector(t) + ΔQ = zerovector!(similar(Q, codomain(Q) ← fuse(codomain(Q)))) + foreachblock(ΔN) do c, (b,) + n = size(b, 2) + ΔQc = block(ΔQ, c) + return copy!(@view(ΔQc[:, (end - n + 1):end]), b) + end + ΔR = ZeroTangent() + MatrixAlgebraKit.qr_compact_pullback!(Δt, (Q, R), (ΔQ, ΔR)) + return NoTangent(), Δt, ZeroTangent(), NoTangent() end + qr_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + + return N, qr_null_pullback end -function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap) - U, S, V⁺ = tsvd(t) - s = diag(S) - project_t = ProjectTo(t) +for lq_f in (:lq_compact, :lq_full) + lq_f! = Symbol(lq_f, '!') + @eval function ChainRulesCore.rrule(::typeof($lq_f!), t::AbstractTensorMap, LQ, alg) + tc = MatrixAlgebraKit.copy_input($lq_f, t) + LQ = $(lq_f!)(tc, LQ, alg) + function lq_pullback(ΔLQ′) + ΔLQ = unthunk.(ΔLQ′) + Δt = zerovector(t) + MatrixAlgebraKit.lq_compact_pullback!(Δt, LQ, ΔLQ) + return NoTangent(), Δt, ZeroTangent(), NoTangent() + end + function lq_pullback(::Tuple{ZeroTangent,ZeroTangent}) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return LQ, lq_pullback + end +end +function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, alg) + L, Q = lq_full(t, alg) + for (c, b) in blocks(t) + m, n = size(b) + copy!(block(Nᴴ, c), view(block(Q, c), (m + 1):n, 1:n)) + end - function svdvals_pullback(Δs′) - Δs = unthunk(Δs′) - ΔS = diagm(codomain(S), domain(S), Δs) - return NoTangent(), project_t(U * ΔS * V⁺) + function lq_null_pullback(ΔNᴴ′) + ΔNᴴ = unthunk(ΔNᴴ′) + Δt = zerovector(t) + ΔQ = zerovector!(similar(Q, codomain(Q) ← fuse(codomain(Q)))) + foreachblock(ΔNᴴ) do c, (b,) + m = size(b, 1) + ΔQc = block(ΔQ, c) + return copy!(@view(ΔQc[(end - m + 1):end, :]), b) + end + ΔL = ZeroTangent() + MatrixAlgebraKit.lq_compact_pullback!(Δt, (L, Q), (ΔL, ΔQ)) + return NoTangent(), Δt, ZeroTangent(), NoTangent() end + lq_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - return s, svdvals_pullback + return Nᴴ, lq_null_pullback end -function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap; - sortby=nothing, kwargs...) - @assert sortby === nothing "only `sortby=nothing` is supported" - (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...) - d = diag(D) - project_t = ProjectTo(t) - function eigvals_pullback(Δd′) - Δd = unthunk(Δd′) - ΔD = diagm(codomain(D), domain(D), Δd) - return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2]) +for eig in (:eig, :eigh) + eig_f = Symbol(eig, "_full") + eig_f! = Symbol(eig_f, "!") + eig_f_pb! = Symbol(eig, "_full_pullback!") + eig_pb = Symbol(eig, "_pullback") + @eval function ChainRulesCore.rrule(::typeof($eig_f!), t::AbstractTensorMap, DV, alg) + tc = copy_input($eig_f, t) + DV = $(eig_f!)(tc, DV, alg) + function $eig_pb(ΔDV) + Δt = zerovector(t) + MatrixAlgebraKit.$eig_f_pb!(Δt, DV, unthunk.(ΔDV)) + return NoTangent(), Δt, ZeroTangent(), NoTangent() + end + function $eig_pb(::Tuple{ZeroTangent,ZeroTangent}) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return DV, $eig_pb end +end - return d, eigvals_pullback +for svd_f in (:svd_compact, :svd_full) + svd_f! = Symbol(svd_f, "!") + @eval begin + function ChainRulesCore.rrule(::typeof($svd_f!), t::AbstractTensorMap, USVᴴ, alg) + tc = copy_input($svd_f, t) + USVᴴ = $(svd_f!)(tc, USVᴴ, alg) + function svd_pullback(ΔUSVᴴ) + Δt = zerovector(t) + MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ)) + return NoTangent(), Δt, ZeroTangent(), NoTangent() + end + function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return USVᴴ, svd_pullback + end + end end -function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) - alg isa MatrixAlgebraKit.LAPACK_HouseholderQR || - error("only `alg=QR()` and `alg=QRpos()` are supported") - QR = leftorth(t; alg) - function leftorth!_pullback(ΔQR′) - ΔQR = unthunk.(ΔQR′) +function ChainRulesCore.rrule(::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ, + alg::TruncatedAlgorithm) + tc = MatrixAlgebraKit.copy_input(svd_compact, t) + USVᴴ = svd_compact!(tc, USVᴴ, alg.alg) + function svd_trunc_pullback(ΔUSVᴴ) Δt = zerovector(t) - foreachblock(Δt) do c, (b,) - QRc = block.(QR, Ref(c)) - ΔQRc = block.(ΔQR, Ref(c)) - qr_compact_pullback!(b, QRc, ΔQRc) - return nothing - end - return NoTangent(), Δt + MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function svd_trunc_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end - leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() + return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, alg.trunc), svd_trunc_pullback +end - return QR, leftorth!_pullback +function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, alg) + tc = copy_input(left_polar, t) + WP = left_polar!(tc, WP, alg) + function left_polar_pullback(ΔWP) + Δt = zerovector(t) + MatrixAlgebraKit.left_polar_pullback!(Δt, WP, unthunk.(ΔWP)) + return NoTangent(), Δt, ZeroTangent(), NoTangent() + end + function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return WP, left_polar_pullback end -function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) - alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ || - error("only `alg=LQ()` and `alg=LQpos()` are supported") - LQ = rightorth(t; alg) - function rightorth!_pullback(ΔLQ′) - ΔLQ = unthunk(ΔLQ′) +function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, alg) + tc = copy_input(left_polar, t) + PWᴴ = right_polar!(Ac, PWᴴ, alg) + function right_polar_pullback(ΔPWᴴ) Δt = zerovector(t) - foreachblock(Δt) do c, (b,) - LQc = block.(LQ, Ref(c)) - ΔLQc = block.(ΔLQ, Ref(c)) - lq_compact_pullback!(b, LQc, ΔLQc) - return nothing - end - return NoTangent(), Δt + MatrixAlgebraKit.right_polar_pullback!(Δt, PWᴴ, unthunk.(ΔPWᴴ)) + return NoTangent(), Δt, ZeroTangent(), NoTangent() + end + function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end - rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() - return LQ, rightorth!_pullback + return PWᴴ, right_polar_pullback end + +# for f in (:tsvd, :eig, :eigh) +# f! = Symbol(f, :!) +# f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!) +# f_pullback = Symbol(f, :_pullback) +# f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!) +# @eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap; +# trunc::TruncationStrategy=TensorKit.notrunc(), +# kwargs...) +# # TODO: I think we can use f! here without issues because we don't actually require +# # the data of `t` anymore. +# F = $f(t; trunc=TensorKit.notrunc(), kwargs...) + +# if trunc != TensorKit.notrunc() && !isempty(blocksectors(t)) +# F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc) +# else +# F′ = F +# end + +# function $f_pullback(ΔF′) +# ΔF = unthunk.(ΔF′) +# Δt = zerovector(t) +# foreachblock(Δt) do c, (b,) +# Fc = block.(F, Ref(c)) +# ΔFc = block.(ΔF, Ref(c)) +# $f_pullback!(b, Fc, ΔFc) +# return nothing +# end +# return NoTangent(), Δt +# end +# $f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent() + +# return F′, $f_pullback +# end +# end + +# function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap) +# U, S, V⁺ = tsvd(t) +# s = diag(S) +# project_t = ProjectTo(t) + +# function svdvals_pullback(Δs′) +# Δs = unthunk(Δs′) +# ΔS = diagm(codomain(S), domain(S), Δs) +# return NoTangent(), project_t(U * ΔS * V⁺) +# end + +# return s, svdvals_pullback +# end + +# function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap; +# sortby=nothing, kwargs...) +# @assert sortby === nothing "only `sortby=nothing` is supported" +# (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...) +# d = diag(D) +# project_t = ProjectTo(t) +# function eigvals_pullback(Δd′) +# Δd = unthunk(Δd′) +# ΔD = diagm(codomain(D), domain(D), Δd) +# return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2]) +# end + +# return d, eigvals_pullback +# end + +# function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) +# alg isa MatrixAlgebraKit.LAPACK_HouseholderQR || +# error("only `alg=QR()` and `alg=QRpos()` are supported") +# QR = leftorth(t; alg) +# function leftorth!_pullback(ΔQR′) +# ΔQR = unthunk.(ΔQR′) +# Δt = zerovector(t) +# foreachblock(Δt) do c, (b,) +# QRc = block.(QR, Ref(c)) +# ΔQRc = block.(ΔQR, Ref(c)) +# qr_compact_pullback!(b, QRc, ΔQRc) +# return nothing +# end +# return NoTangent(), Δt +# end +# leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() + +# return QR, leftorth!_pullback +# end + +# function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) +# alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ || +# error("only `alg=LQ()` and `alg=LQpos()` are supported") +# LQ = rightorth(t; alg) +# function rightorth!_pullback(ΔLQ′) +# ΔLQ = unthunk(ΔLQ′) +# Δt = zerovector(t) +# foreachblock(Δt) do c, (b,) +# LQc = block.(LQ, Ref(c)) +# ΔLQc = block.(ΔLQ, Ref(c)) +# lq_compact_pullback!(b, LQc, ΔLQc) +# return nothing +# end +# return NoTangent(), Δt +# end +# rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() +# return LQ, rightorth!_pullback +# end diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index df7189c54..b9f8acd52 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -2,10 +2,10 @@ # ---------------- # map algorithms to their adjoint counterpart # TODO: this probably belongs in MatrixAlgebraKit -_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.positive, alg.blocksize) -_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.positive, alg.blocksize) -_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.positive, alg.blocksize) -_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.positive, alg.blocksize) +_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.kwargs...) +_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.kwargs...) +_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.kwargs...) +_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.kwargs...) _adjoint(alg::PolarViaSVD) = PolarViaSVD(_adjoint(alg.svdalg)) _adjoint(alg::AbstractAlgorithm) = alg diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 6f447612a..14c5b80b7 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -42,12 +42,16 @@ import MatrixAlgebraKit: default_algorithm, left_orth!, right_orth!, left_null!, right_null!, truncate!, findtruncated, findtruncated_svd, diagview, isisometry +import MatrixAlgebraKit: qr_compact_pullback!, lq_compact_pullback!, svd_compact_pullback!, + left_polar_pullback!, right_polar_pullback!, eig_full_pullback!, + eigh_full_pullback! include("utility.jl") include("matrixalgebrakit.jl") include("truncation.jl") include("adjoint.jl") include("diagonal.jl") +include("pullbacks.jl") TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A) diff --git a/src/tensors/factorizations/pullbacks.jl b/src/tensors/factorizations/pullbacks.jl new file mode 100644 index 000000000..6d6028384 --- /dev/null +++ b/src/tensors/factorizations/pullbacks.jl @@ -0,0 +1,13 @@ +for pullback! in (:qr_compact_pullback!, :lq_compact_pullback!, + :svd_compact_pullback!, + :left_polar_pullback!, :right_polar_pullback!, + :eig_full_pullback!, :eigh_full_pullback!) + @eval function $pullback!(Δt::AbstractTensorMap, F, ΔF; kwargs...) + foreachblock(Δt) do c, (b,) + Fc = block.(F, Ref(c)) + ΔFc = block.(ΔF, Ref(c)) + return $pullback!(b, Fc, ΔFc; kwargs...) + end + return Δt + end +end diff --git a/test/ad.jl b/test/ad.jl index 4013846d0..52f7dee90 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -4,6 +4,8 @@ using FiniteDifferences: FiniteDifferences using Random using LinearAlgebra using Zygote +using MatrixAlgebraKit +using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ const _repartition = @static if isdefined(Base, :get_extension) Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition @@ -26,6 +28,7 @@ function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, for (c, b) in blocks(actual) ChainRulesTestUtils.@test_msg msg isapprox(b, block(expected, c); kwargs...) end + return nothing end # make sure that norms are computed correctly: @@ -50,8 +53,8 @@ function FiniteDifferences.to_vec(t::TensorKit.SectorDict) end # Float32 and finite differences don't mix well -precision(::Type{<:Union{Float32,Complex{Float32}}}) = 1e-2 -precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1e-6 +precision(::Type{<:Union{Float32,Complex{Float32}}}) = 1.0e-2 +precision(::Type{<:Union{Float64,Complex{Float64}}}) = 1.0e-6 function randindextuple(N::Int, k::Int=rand(0:N)) @assert 0 ≤ k ≤ N @@ -59,47 +62,101 @@ function randindextuple(N::Int, k::Int=rand(0:N)) return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) end +function test_ad_rrule(f, args...; check_inferred=false, kwargs...) + test_rrule(Zygote.ZygoteRuleConfig(), f, args...; + rrule_f=rrule_via_ad, check_inferred, kwargs...) + return nothing +end + # rrules for functions that destroy inputs # ---------------------------------------- -function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), args...; kwargs...) - return ChainRulesCore.rrule(tsvd!, args...; kwargs...) -end -function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals), args...; kwargs...) - return ChainRulesCore.rrule(svdvals!, args...; kwargs...) -end -function ChainRulesCore.rrule(::typeof(TensorKit.eig), args...; kwargs...) - return ChainRulesCore.rrule(eig!, args...; kwargs...) -end -function ChainRulesCore.rrule(::typeof(TensorKit.eigh), args...; kwargs...) - return ChainRulesCore.rrule(eigh!, args...; kwargs...) +for f in + (:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, + :eig_full, :eigh_full, :svd_compact, :svd_trunc, :left_polar, :right_polar) + copy_f = Symbol(:copy_, f) + f! = Symbol(f, '!') + @eval begin + function $copy_f(input) + if $f === eigh_full + input = (input + input') / 2 + end + return $f(input) + end + function ChainRulesCore.rrule(::typeof($copy_f), input) + output = MatrixAlgebraKit.initialize_output($f!, input) + if $f === eigh_full + input = (input + input') / 2 + else + input = copy!(similar(input), input) + end + + output, pb = ChainRulesCore.rrule($f!, input, output) + return output, x -> (NoTangent(), pb(x)[2], NoTangent()) + end + end end -function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals), args...; kwargs...) - return ChainRulesCore.rrule(eigvals!, args...; kwargs...) + + +# Gauge fixing tangents +# --------------------- +function remove_qrgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:m, 1:minmn) + ΔQ2 = view(b, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + end + return ΔQ end -function ChainRulesCore.rrule(::typeof(TensorKit.leftorth), args...; kwargs...) - return ChainRulesCore.rrule(leftorth!, args...; kwargs...) + +function remove_lqgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:minmn, 1:n) + ΔQ2 = view(b, (minmn + 1):n, :) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + end + return ΔQ end -function ChainRulesCore.rrule(::typeof(TensorKit.rightorth), args...; kwargs...) - return ChainRulesCore.rrule(rightorth!, args...; kwargs...) +function remove_eiggauge_dependence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(D)) + gaugepart = V' * ΔV + for (c, b) in blocks(gaugepart) + Dc = block(D, c) + mask = abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol + b[mask] .= 0 + end + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV end - -# eigh′: make argument of eigh explicitly Hermitian -#--------------------------------------------------- -eigh′(t::AbstractTensorMap) = eigh(scale!(t + t', 1 / 2)) - -function ChainRulesCore.rrule(::typeof(eigh′), args...; kwargs...) - return ChainRulesCore.rrule(eigh!, args...; kwargs...) +function remove_eighgauge_dependence!(ΔV, D, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(D)) + gaugepart = V' * ΔV + gaugepart = (gaugepart - gaugepart') / 2 + for (c, b) in blocks(gaugepart) + Dc = block(D, c) + mask = abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol + b[mask] .= 0 + end + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV end -# complex-valued svd? -# ------------------- -function remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - # simple implementation, assumes no degeneracies or zeros in singular values +function remove_svdgauge_dependence!(ΔU, ΔV, U, S, V; + degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) gaugepart = U' * ΔU + V * ΔV' + gaugepart = (gaugepart - gaugepart') / 2 for (c, b) in blocks(gaugepart) - mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1) + Sc = block(S, c) + mask = abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol + b[mask] .= 0 end - return ΔU, ΔV + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ end # Tests @@ -107,365 +164,473 @@ end ChainRulesTestUtils.test_method_tables() -Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), - (ℂ[Z2Irrep](0 => 1, 1 => 1), - ℂ[Z2Irrep](0 => 1, 1 => 2)', - ℂ[Z2Irrep](0 => 3, 1 => 2)', - ℂ[Z2Irrep](0 => 2, 1 => 3), - ℂ[Z2Irrep](0 => 2, 1 => 2)), - (ℂ[FermionParity](0 => 1, 1 => 1), - ℂ[FermionParity](0 => 1, 1 => 2)', - ℂ[FermionParity](0 => 2, 1 => 2)', - ℂ[FermionParity](0 => 2, 1 => 3), - ℂ[FermionParity](0 => 2, 1 => 2)), - (ℂ[U1Irrep](0 => 2, 1 => 1, -1 => 1), - ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), - ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', - ℂ[U1Irrep](0 => 1, 1 => 1, -1 => 2), - ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 1)'), - (ℂ[SU2Irrep](0 => 2, 1 // 2 => 1), - ℂ[SU2Irrep](0 => 1, 1 => 1), - ℂ[SU2Irrep](1 // 2 => 1, 1 => 1)', - ℂ[SU2Irrep](1 // 2 => 2), - ℂ[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'), - (ℂ[FibonacciAnyon](:I => 1, :τ => 1), - ℂ[FibonacciAnyon](:I => 1, :τ => 2)', - ℂ[FibonacciAnyon](:I => 3, :τ => 2)', - ℂ[FibonacciAnyon](:I => 2, :τ => 3), - ℂ[FibonacciAnyon](:I => 2, :τ => 2))) - -@timedtestset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in - Vlist +spacelist = try + if ENV["CI"] == "true" + println("Detected running on CI") + if Sys.iswindows() + (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂) + elseif Sys.isapple() + (Vtr, Vℤ₃, VfU₁, VfSU₂) + else + (Vtr, VU₁, VCU₁, VfSU₂, Vfib) + end + else + (Vtr, Vℤ₃, VU₁, VfU₁, VSU₂, VfSU₂, Vfib) + end +catch + (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, Vfib) +end + +for V in spacelist + I = sectortype(eltype(V)) + Istr = TensorKit.type_repr(I) eltypes = isreal(sectortype(eltype(V))) ? (Float64, ComplexF64) : (ComplexF64,) symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + println("---------------------------------------") + println("Auto-diff with symmetry: $Istr") + println("---------------------------------------") + @timedtestset "AD with symmetry $Istr" verbose = true begin + V1, V2, V3, V4, V5 = V + W = V1 ⊗ V2 + false && @timedtestset "Basic utility" begin + T1 = randn(Float64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) + T2 = randn(ComplexF64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) + + P1 = ProjectTo(T1) + @test P1(T1) == T1 + @test P1(T2) == real(T2) + + test_rrule(copy, T1) + test_rrule(copy, T2) + test_rrule(TensorKit.copy_oftype, T1, ComplexF64) + if symmetricbraiding + test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, + ((3, 1), (2, 4))) + + test_rrule(convert, Array, T1) + test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1); + fkwargs=(; tol=Inf)) + end - @timedtestset "Basic utility" begin - T1 = randn(Float64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) - T2 = randn(ComplexF64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) - - P1 = ProjectTo(T1) - @test P1(T1) == T1 - @test P1(T2) == real(T2) - - test_rrule(copy, T1) - test_rrule(copy, T2) - test_rrule(TensorKit.copy_oftype, T1, ComplexF64) - if symmetricbraiding - test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4))) - - test_rrule(convert, Array, T1) - test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1); - fkwargs=(; tol=Inf)) + test_rrule(Base.getproperty, T1, :data) + test_rrule(TensorMap{scalartype(T1)}, T1.data, T1.space) + test_rrule(Base.getproperty, T2, :data) + test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space) end - test_rrule(Base.getproperty, T1, :data) - test_rrule(TensorMap{scalartype(T1)}, T1.data, T1.space) - test_rrule(Base.getproperty, T2, :data) - test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space) - end - - @timedtestset "Basic utility (DiagonalTensor)" begin - for v in V - rdim = reduceddim(v) - D1 = DiagonalTensorMap(randn(rdim), v) - D2 = DiagonalTensorMap(randn(rdim), v) - D = D1 + im * D2 - T1 = TensorMap(D1) - T2 = TensorMap(D2) - T = T1 + im * T2 - - # real -> real - P1 = ProjectTo(D1) - @test P1(D1) == D1 - @test P1(T1) == D1 - - # complex -> complex - P2 = ProjectTo(D) - @test P2(D) == D - @test P2(T) == D - - # real -> complex - @test P2(D1) == D1 + 0 * im * D1 - @test P2(T1) == D1 + 0 * im * D1 - - # complex -> real - @test P1(D) == D1 - @test P1(T) == D1 - - test_rrule(DiagonalTensorMap, D1.data, D1.domain) - test_rrule(DiagonalTensorMap, D.data, D.domain) - test_rrule(Base.getproperty, D, :data) - test_rrule(Base.getproperty, D1, :data) - - test_rrule(DiagonalTensorMap, rand!(T1)) - test_rrule(DiagonalTensorMap, randn!(T)) + false && @timedtestset "Basic utility (DiagonalTensor)" begin + for v in V + rdim = reduceddim(v) + D1 = DiagonalTensorMap(randn(rdim), v) + D2 = DiagonalTensorMap(randn(rdim), v) + D = D1 + im * D2 + T1 = TensorMap(D1) + T2 = TensorMap(D2) + T = T1 + im * T2 + + # real -> real + P1 = ProjectTo(D1) + @test P1(D1) == D1 + @test P1(T1) == D1 + + # complex -> complex + P2 = ProjectTo(D) + @test P2(D) == D + @test P2(T) == D + + # real -> complex + @test P2(D1) == D1 + 0 * im * D1 + @test P2(T1) == D1 + 0 * im * D1 + + # complex -> real + @test P1(D) == D1 + @test P1(T) == D1 + + test_rrule(DiagonalTensorMap, D1.data, D1.domain) + test_rrule(DiagonalTensorMap, D.data, D.domain) + test_rrule(Base.getproperty, D, :data) + test_rrule(Base.getproperty, D1, :data) + + test_rrule(DiagonalTensorMap, rand!(T1)) + test_rrule(DiagonalTensorMap, randn!(T)) + end end - end - @timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes - A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - B = randn(T, space(A)) + false && @timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + B = randn(T, space(A)) - test_rrule(real, A) - test_rrule(imag, A) + test_rrule(real, A) + test_rrule(imag, A) - test_rrule(+, A, B) - test_rrule(-, A) - test_rrule(-, A, B) + test_rrule(+, A, B) + test_rrule(-, A) + test_rrule(-, A, B) - α = randn(T) - test_rrule(*, α, A) - test_rrule(*, A, α) + α = randn(T) + test_rrule(*, α, A) + test_rrule(*, A, α) - C = randn(T, domain(A), codomain(A)) - test_rrule(*, A, C) + C = randn(T, domain(A), codomain(A)) + test_rrule(*, A, C) - symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4))) - test_rrule(twist, A, 1) - test_rrule(twist, A, [1, 3]) + symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4))) + test_rrule(twist, A, 1) + test_rrule(twist, A, [1, 3]) - test_rrule(flip, A, 1) - test_rrule(flip, A, [1, 3, 4]) + test_rrule(flip, A, 1) + test_rrule(flip, A, [1, 3, 4]) - D = randn(T, V[1] ⊗ V[2] ← V[3]) - E = randn(T, V[4] ← V[5]) - symmetricbraiding && test_rrule(⊗, D, E) - end - - @timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes - for i in 1:3 - E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...)) - test_rrule(LinearAlgebra.tr, E) - test_rrule(exp, E; check_inferred=false) - test_rrule(inv, E) + D = randn(T, V[1] ⊗ V[2] ← V[3]) + E = randn(T, V[4] ← V[5]) + symmetricbraiding && test_rrule(⊗, D, E) end - A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - test_rrule(LinearAlgebra.adjoint, A) - test_rrule(LinearAlgebra.norm, A, 2) + false && @timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes + for i in 1:3 + E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...)) + test_rrule(LinearAlgebra.tr, E) + test_rrule(exp, E; check_inferred=false) + test_rrule(inv, E) + end - B = randn(T, space(A)) - test_rrule(LinearAlgebra.dot, A, B) - end + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + test_rrule(LinearAlgebra.adjoint, A) + test_rrule(LinearAlgebra.norm, A, 2) - @timedtestset "Matrix functions ($T)" for T in eltypes - for f in (sqrt, exp) - check_inferred = false # !(T <: Real) # not type-stable for real functions - t1 = randn(T, V[1] ← V[1]) - t2 = randn(T, V[2] ← V[2]) - d = DiagonalTensorMap{T}(undef, V[1]) - (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data) - d2 = DiagonalTensorMap{T}(undef, V[1]) - (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data) - test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred) - test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred) - test_rrule(f, d; check_inferred, output_tangent=d2) + B = randn(T, space(A)) + test_rrule(LinearAlgebra.dot, A, B) end - end - symmetricbraiding && - @timedtestset "TensorOperations with scalartype $T" for T in eltypes - atol = precision(T) - rtol = precision(T) + @timedtestset "Matrix functions ($T)" for T in eltypes + for f in (sqrt, exp) + check_inferred = false # !(T <: Real) # not type-stable for real functions + t1 = randn(T, V[1] ← V[1]) + t2 = randn(T, V[2] ← V[2]) + d = DiagonalTensorMap{T}(undef, V[1]) + (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data) + d2 = DiagonalTensorMap{T}(undef, V[1]) + (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data) + test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred) + test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred) + test_rrule(f, d; check_inferred, output_tangent=d2) + end + end - @timedtestset "tensortrace!" begin - for _ in 1:5 - k1 = rand(0:3) - k2 = k1 == 3 ? 1 : rand(1:2) - V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) - V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + false && symmetricbraiding && + @timedtestset "TensorOperations with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) - (_p, _q) = randindextuple(k1 + 2 * k2, k1) - p = _repartition(_p, rand(0:k1)) - q = _repartition(_q, k2) - ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) - A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + @timedtestset "tensortrace!" begin + for _ in 1:5 + k1 = rand(0:3) + k2 = k1 == 3 ? 1 : rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), + rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + for conjA in (false, true) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, + Val(false))) + test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol) + end + end + end + @timedtestset "tensoradd!" begin + A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5]) α = randn(T) β = randn(T) - for conjA in (false, true) - C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, - Val(false))) - test_rrule(tensortrace!, C, A, p, q, conjA, α, β; atol, rtol) - end - end - end - @timedtestset "tensoradd!" begin - A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5]) - α = randn(T) - β = randn(T) + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randindextuple(length(V)) - # repeat a couple times to get some distribution of arrows - for _ in 1:5 - p = randindextuple(length(V)) + C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, + Val(false))) + test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol) - C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, - Val(false))) - test_rrule(tensoradd!, C1, A, p, false, α, β; atol, rtol) + C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, + Val(false))) + test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol) - C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false))) - test_rrule(tensoradd!, C2, A, p, true, α, β; atol, rtol) + A = rand(Bool) ? C1 : C2 + end + end - A = rand(Bool) ? C1 : C2 + @timedtestset "tensorcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), + dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init=one(V[1])) + + for conjA in (false, true), conjB in (false, true) + A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA)) + B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB)) + C = randn!(TensorOperations.tensoralloc_contract(T, A, pA, + conjA, + B, pB, conjB, + pAB, + Val(false))) + test_rrule(tensorcontract!, C, + A, pA, conjA, B, pB, conjB, pAB, + α, β; atol, rtol) + end + end + end + + @timedtestset "tensorscalar" begin + A = randn(T, ProductSpace{typeof(V[1]),0}()) + test_rrule(tensorscalar, A) end end - @timedtestset "tensorcontract!" begin - for _ in 1:5 - d = 0 - local V1, V2, V3 - # retry a couple times to make sure there are at least some nonzero elements - for _ in 1:10 - k1 = rand(0:3) - k2 = rand(0:2) - k3 = rand(0:2) - V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init=one(V[1])) - V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init=one(V[1])) - V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init=one(V[1])) - d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) - d > 0 && break + @timedtestset "Factorizations" begin + @testset "QR" begin + for T in eltypes, + t in (randn(T, W, W), randn(T, W, W)', + randn(T, W, V1), randn(T, V1, W), + randn(T, W, V1)', randn(T, V1, W)', + DiagonalTensorMap(randn(T, reduceddim(V1)), V1)) + + atol = rtol = precision(T) * dim(space(t)) + fkwargs = (; positive=true) # make FiniteDifferences happy + + test_ad_rrule(qr_compact, t; fkwargs, atol, rtol) + test_ad_rrule(first ∘ qr_compact, t; fkwargs, atol, rtol,) + test_ad_rrule(last ∘ qr_compact, t; fkwargs, atol, rtol,) + + # qr_full/qr_null requires being careful with gauges + Q, R = qr_full(t) + ΔQ = rand_tangent(Q) + ΔR = rand_tangent(R) + + if fuse(domain(t)) ≺ fuse(codomain(t)) + _, full_pb = Zygote.pullback(qr_full, t) + @test_logs (:warn, r"^`qr") match_mode = :any full_pb((ΔQ, ΔR)) end - ipA = randindextuple(length(V1) + length(V2)) - pA = _repartition(invperm(linearize(ipA)), length(V1)) - ipB = randindextuple(length(V2) + length(V3)) - pB = _repartition(invperm(linearize(ipB)), length(V2)) - pAB = randindextuple(length(V1) + length(V3)) - α = randn(T) - β = randn(T) - V2_conj = prod(conj, V2; init=one(V[1])) - - for conjA in (false, true), conjB in (false, true) - A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA)) - B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB)) - C = randn!(TensorOperations.tensoralloc_contract(T, A, pA, - conjA, - B, pB, conjB, pAB, - Val(false))) - test_rrule(tensorcontract!, C, - A, pA, conjA, B, pB, conjB, pAB, - α, β; atol, rtol) - end + remove_qrgauge_dependence!(ΔQ, t, Q) + + test_ad_rrule(qr_full, t; fkwargs, atol, rtol, output_tangent=(ΔQ, ΔR)) + test_ad_rrule(first ∘ qr_full, t; fkwargs, atol, rtol, + output_tangent=ΔQ) + test_ad_rrule(last ∘ qr_full, t; fkwargs, atol, rtol, output_tangent=ΔR) + + # TODO: figure out the following: + # N = qr_null(t) + # ΔN = Q * rand(T, domain(Q) ← domain(N)) + # test_ad_rrule(qr_null, t; fkwargs, atol, rtol, output_tangent=ΔN) + + # if fuse(domain(t)) ≺ fuse(codomain(t)) + # _, null_pb = Zygote.pullback(qr_null, t) + # @test_logs (:warn, r"^`qr") match_mode = :any null_pb(rand_tangent(N)) + # end end end - @timedtestset "tensorscalar" begin - A = randn(T, ProductSpace{typeof(V[1]),0}()) - test_rrule(tensorscalar, A) - end - end + @testset "LQ" begin + for T in eltypes, + t in (randn(T, W, W), randn(T, W, W)', + randn(T, W, V1), randn(T, V1, W), + randn(T, W, V1)', randn(T, V1, W)', + DiagonalTensorMap(randn(T, reduceddim(V1)), V1)) + + atol = rtol = precision(T) * dim(space(t)) + fkwargs = (; positive=true) # make FiniteDifferences happy + + test_ad_rrule(lq_compact, t; fkwargs, atol, rtol) + test_ad_rrule(first ∘ lq_compact, t; fkwargs, atol, rtol) + test_ad_rrule(last ∘ lq_compact, t; fkwargs, atol, rtol) + + # lq_full/lq_null requires being careful with gauges + L, Q = lq_full(t) + ΔQ = rand_tangent(Q) + ΔL = rand_tangent(L) + + if fuse(codomain(t)) ≺ fuse(domain(t)) + _, full_pb = Zygote.pullback(lq_full, t) + # broken due to typo in MAK + # @test_logs (:warn, r"^`lq") match_mode = :any full_pb((ΔL, ΔQ)) + end - @timedtestset "Factorizations with scalartype $T" for T in eltypes - A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - B = randn(T, space(A)') - C = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) - H = randn(T, V[3] ⊗ V[4] ← V[3] ⊗ V[4]) - H = (H + H') / 2 - atol = precision(T) - - for alg in (TensorKit.QR(), TensorKit.QRpos()) - test_rrule(leftorth, A; fkwargs=(; alg=alg), atol) - test_rrule(leftorth, B; fkwargs=(; alg=alg), atol) - test_rrule(leftorth, C; fkwargs=(; alg=alg), atol) - end + remove_lqgauge_dependence!(ΔQ, t, Q) - for alg in (TensorKit.LQ(), TensorKit.LQpos()) - test_rrule(rightorth, A; fkwargs=(; alg=alg), atol) - test_rrule(rightorth, B; fkwargs=(; alg=alg), atol) - test_rrule(rightorth, C; fkwargs=(; alg=alg), atol) - end + test_ad_rrule(lq_full, t; fkwargs, atol, rtol, output_tangent=(ΔL, ΔQ)) + test_ad_rrule(first ∘ lq_full, t; fkwargs, atol, rtol, + output_tangent=ΔL) + test_ad_rrule(last ∘ lq_full, t; fkwargs, atol, rtol, output_tangent=ΔQ) - let (D, V) = eig(C) - ΔD = randn(scalartype(D), space(D)) - ΔV = randn(scalartype(V), space(V)) - gaugepart = V' * ΔV - for (c, b) in blocks(gaugepart) - mul!(block(ΔV, c), inv(block(V, c))', Diagonal(diag(b)), -1, 1) - end - test_rrule(eig, C; atol, output_tangent=(ΔD, ΔV)) - end + # TODO: figure out the following + # Nᴴ = lq_null(t) + # ΔN = rand(T, codomain(Nᴴ) ← codomain(Q)) * Q + # test_ad_rrule(lq_null, t; fkwargs, atol, rtol, output_tangent=Nᴴ) - let (D, U) = eigh′(H) - ΔD = randn(scalartype(D), space(D)) - ΔU = randn(scalartype(U), space(U)) - if T <: Complex - gaugepart = U' * ΔU - for (c, b) in blocks(gaugepart) - mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1) + # if fuse(codomain(t)) ≺ fuse(domain(t)) + # _, null_pb = Zygote.pullback(lq_null, t) + # # broken due to typo in MAK + # # @test_logs (:warn, r"^`lq") match_mode = :any null_pb(rand_tangent(Nᴴ)) + # end end end - test_rrule(eigh′, H; atol, output_tangent=(ΔD, ΔU)) - end - let (U, S, V) = tsvd(A) - ΔU = randn(scalartype(U), space(U)) - ΔS = randn(scalartype(S), space(S)) - ΔV = randn(scalartype(V), space(V)) - if T <: Complex # remove gauge dependent components - gaugepart = U' * ΔU + V * ΔV' - for (c, b) in blocks(gaugepart) - mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1) + @testset "Eigenvalue decomposition" begin + for T in eltypes, + t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) + + atol = rtol = precision(T) * dim(space(t)) + + d, v = eig_full(t) + Δv = rand_tangent(v) + Δd = rand_tangent(d) + Δd2 = randn!(similar(d, space(d))) + remove_eiggauge_dependence!(Δv, d, v) + + test_ad_rrule(eig_full, t; output_tangent=(Δd, Δv), atol, rtol) + test_ad_rrule(first ∘ eig_full, t; output_tangent=Δd, atol, rtol) + test_ad_rrule(last ∘ eig_full, t; output_tangent=Δv, atol, rtol) + test_ad_rrule(eig_full, t; output_tangent=(Δd2, Δv), atol, rtol) + + add!(t, t') + d, v = eigh_full(t) + Δv = rand_tangent(v) + Δd = rand_tangent(d) + Δd2 = randn!(similar(d, space(d))) + remove_eighgauge_dependence!(Δv, d, v) + + test_ad_rrule(eigh_full, t; output_tangent=(Δd, Δv), atol, rtol) + test_ad_rrule(first ∘ eigh_full, t; output_tangent=Δd, atol, rtol) + test_ad_rrule(last ∘ eigh_full, t; output_tangent=Δv, atol, rtol) + test_ad_rrule(eigh_full, t; output_tangent=(Δd2, Δv), atol, rtol) end end - test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV)) - - allS = mapreduce(x -> diag(x[2]), vcat, blocks(S)) - truncval = (maximum(allS) + minimum(allS)) / 2 - U, S, V = tsvd(A; trunc=truncerr(truncval)) - ΔU = randn(scalartype(U), space(U)) - ΔS = randn(scalartype(S), space(S)) - ΔV = randn(scalartype(V), space(V)) - T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV), - fkwargs=(; trunc=truncerr(truncval))) - end - let (U, S, V) = tsvd(B) - ΔU = randn(scalartype(U), space(U)) - ΔS = randn(scalartype(S), space(S)) - ΔV = randn(scalartype(V), space(V)) - T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV)) - - Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2) - for (c, b) in blocks(S))) - - U, S, V = tsvd(B; trunc=truncspace(Vtrunc)) - ΔU = randn(scalartype(U), space(U)) - ΔS = randn(scalartype(S), space(S)) - ΔV = randn(scalartype(V), space(V)) - T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV), - fkwargs=(; trunc=truncspace(Vtrunc))) - end - - let (U, S, V) = tsvd(C) - ΔU = randn(scalartype(U), space(U)) - ΔS = randn(scalartype(S), space(S)) - ΔV = randn(scalartype(V), space(V)) - T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV)) - - c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) - trunc = truncdim(round(Int, 2 * dim(c))) - U, S, V = tsvd(C; trunc) - ΔU = randn(scalartype(U), space(U)) - ΔS = randn(scalartype(S), space(S)) - ΔV = randn(scalartype(V), space(V)) - T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV), fkwargs=(; trunc)) - end - - let D = LinearAlgebra.eigvals(C) - ΔD = diag(randn(complex(scalartype(C)), space(C))) - test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD, - fkwargs=(; sortby=nothing)) - end + @testset "Singular value decomposition" begin + for T in eltypes, + t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) + + atol = rtol = degeneracy_atol = precision(T) * dim(space(t)) + USVᴴ = svd_compact(t) + ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ) + ΔS2 = randn!(similar(ΔS, space(ΔS))) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol) + + test_ad_rrule(svd_full, t; output_tangent=(ΔU, ΔS, ΔVᴴ), atol, rtol) + test_ad_rrule(svd_full, t; output_tangent=(ΔU, ΔS2, ΔVᴴ), atol, rtol) + test_ad_rrule(svd_compact, t; output_tangent=(ΔU, ΔS, ΔVᴴ), atol, rtol) + test_ad_rrule(svd_compact, t; output_tangent=(ΔU, ΔS2, ΔVᴴ), atol, rtol) + + trunc = truncrank(min(dim(domain(t)), dim(codomain(t))) ÷ 2) + USVᴴ′ = svd_trunc(t; trunc) + ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ′) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol) + + test_ad_rrule(svd_trunc, t; + fkwargs=(; trunc), output_tangent=(ΔU, ΔS, ΔVᴴ), atol, + rtol) + end + end - let S = LinearAlgebra.svdvals(C) - ΔS = diag(randn(real(scalartype(C)), space(C))) - test_rrule(LinearAlgebra.svdvals, C; atol, output_tangent=ΔS) + # let (U, S, V) = tsvd(A) + # ΔU = randn(scalartype(U), space(U)) + # ΔS = randn(scalartype(S), space(S)) + # ΔV = randn(scalartype(V), space(V)) + # if T <: Complex # remove gauge dependent components + # gaugepart = U' * ΔU + V * ΔV' + # for (c, b) in blocks(gaugepart) + # mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1) + # end + # end + # test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV)) + + # allS = mapreduce(x -> diag(x[2]), vcat, blocks(S)) + # truncval = (maximum(allS) + minimum(allS)) / 2 + # U, S, V = tsvd(A; trunc=truncerror(; atol=truncval)) + # ΔU = randn(scalartype(U), space(U)) + # ΔS = randn(scalartype(S), space(S)) + # ΔV = randn(scalartype(V), space(V)) + # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) + # test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV), + # fkwargs=(; trunc=truncerror(; atol=truncval))) + # end + + # let (U, S, V) = tsvd(B) + # ΔU = randn(scalartype(U), space(U)) + # ΔS = randn(scalartype(S), space(S)) + # ΔV = randn(scalartype(V), space(V)) + # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) + # test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV)) + + # Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2) + # for (c, b) in blocks(S))) + + # U, S, V = tsvd(B; trunc=truncspace(Vtrunc)) + # ΔU = randn(scalartype(U), space(U)) + # ΔS = randn(scalartype(S), space(S)) + # ΔV = randn(scalartype(V), space(V)) + # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) + # test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV), + # fkwargs=(; trunc=truncspace(Vtrunc))) + # end + + # let (U, S, V) = tsvd(C) + # ΔU = randn(scalartype(U), space(U)) + # ΔS = randn(scalartype(S), space(S)) + # ΔV = randn(scalartype(V), space(V)) + # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) + # test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV)) + + # c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) + # trunc = truncrank(round(Int, 2 * dim(c))) + # U, S, V = tsvd(C; trunc) + # ΔU = randn(scalartype(U), space(U)) + # ΔS = randn(scalartype(S), space(S)) + # ΔV = randn(scalartype(V), space(V)) + # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) + # test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV), fkwargs=(; trunc)) + # end + + # let D = LinearAlgebra.eigvals(C) + # ΔD = diag(randn(complex(scalartype(C)), space(C))) + # test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD, + # fkwargs=(; sortby=nothing)) + # end + + # let S = LinearAlgebra.svdvals(C) + # ΔS = diag(randn(real(scalartype(C)), space(C))) + # test_rrule(LinearAlgebra.svdvals, C; atol, output_tangent=ΔS) + # end end end end diff --git a/test/factorizations.jl b/test/factorizations.jl index 33ebf0f56..4022ff3e6 100644 --- a/test/factorizations.jl +++ b/test/factorizations.jl @@ -16,11 +16,12 @@ catch end eltypes = (Float32, ComplexF64) + for V in spacelist I = sectortype(first(V)) Istr = TensorKit.type_repr(I) println("---------------------------------------") - println("Tensors with symmetry: $Istr") + println("Factorizations with symmetry: $Istr") println("---------------------------------------") @timedtestset "Factorizations with symmetry: $Istr" verbose = true begin V1, V2, V3, V4, V5 = V @@ -254,9 +255,8 @@ for V in spacelist @testset "Eigenvalue decomposition" begin for T in eltypes, - t in - (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', - DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) + t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', + DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) d, v = @constinferred eig_full(t) @test t * v ≈ v * d From ff9e391b729a2646e20b3d72cc3731e9d6e777c2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 29 Sep 2025 17:04:15 -0400 Subject: [PATCH 108/126] Adapt to MatrixAlgebraKit v0.4.1 --- Project.toml | 2 +- .../TensorKitChainRulesCoreExt.jl | 6 +- .../factorizations.jl | 115 ++-------- src/tensors/factorizations/factorizations.jl | 8 +- src/tensors/factorizations/pullbacks.jl | 41 +++- src/tensors/factorizations/truncation.jl | 107 ++++++---- test/ad.jl | 198 +++++++----------- test/runtests.jl | 10 +- 8 files changed, 215 insertions(+), 272 deletions(-) diff --git a/Project.toml b/Project.toml index e9790c851..ff1dd832a 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.4.0" +MatrixAlgebraKit = "0.4.1" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl index 272b47c6d..727693816 100644 --- a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -14,9 +14,9 @@ using VectorInterface: promote_scale, promote_add using MatrixAlgebraKit using MatrixAlgebraKit: TruncationStrategy, TruncatedAlgorithm, - svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!, - qr_compact_pullback!, lq_compact_pullback!, left_polar_pullback!, - right_polar_pullback! + svd_pullback!, eig_pullback!, eigh_pullback!, + qr_compact_pullback!, lq_compact_pullback!, + left_polar_pullback!, right_polar_pullback! include("utility.jl") include("constructors.jl") diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index aac5f0af6..2746f5444 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -18,7 +18,7 @@ for qr_f in (:qr_compact, :qr_full) function qr_pullback(ΔQR′) ΔQR = unthunk.(ΔQR′) Δt = zerovector(t) - MatrixAlgebraKit.qr_compact_pullback!(Δt, QR, ΔQR) + MatrixAlgebraKit.qr_compact_pullback!(Δt, t, QR, ΔQR) return NoTangent(), Δt, ZeroTangent(), NoTangent() end function qr_pullback(::Tuple{ZeroTangent,ZeroTangent}) @@ -44,7 +44,7 @@ function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg) return copy!(@view(ΔQc[:, (end - n + 1):end]), b) end ΔR = ZeroTangent() - MatrixAlgebraKit.qr_compact_pullback!(Δt, (Q, R), (ΔQ, ΔR)) + MatrixAlgebraKit.qr_compact_pullback!(Δt, t, (Q, R), (ΔQ, ΔR)) return NoTangent(), Δt, ZeroTangent(), NoTangent() end qr_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() @@ -60,7 +60,7 @@ for lq_f in (:lq_compact, :lq_full) function lq_pullback(ΔLQ′) ΔLQ = unthunk.(ΔLQ′) Δt = zerovector(t) - MatrixAlgebraKit.lq_compact_pullback!(Δt, LQ, ΔLQ) + MatrixAlgebraKit.lq_compact_pullback!(Δt, t, LQ, ΔLQ) return NoTangent(), Δt, ZeroTangent(), NoTangent() end function lq_pullback(::Tuple{ZeroTangent,ZeroTangent}) @@ -86,7 +86,7 @@ function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, al return copy!(@view(ΔQc[(end - m + 1):end, :]), b) end ΔL = ZeroTangent() - MatrixAlgebraKit.lq_compact_pullback!(Δt, (L, Q), (ΔL, ΔQ)) + MatrixAlgebraKit.lq_compact_pullback!(Δt, t, (L, Q), (ΔL, ΔQ)) return NoTangent(), Δt, ZeroTangent(), NoTangent() end lq_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() @@ -97,14 +97,14 @@ end for eig in (:eig, :eigh) eig_f = Symbol(eig, "_full") eig_f! = Symbol(eig_f, "!") - eig_f_pb! = Symbol(eig, "_full_pullback!") + eig_f_pb! = Symbol(eig, "_pullback!") eig_pb = Symbol(eig, "_pullback") @eval function ChainRulesCore.rrule(::typeof($eig_f!), t::AbstractTensorMap, DV, alg) - tc = copy_input($eig_f, t) + tc = MatrixAlgebraKit.copy_input($eig_f, t) DV = $(eig_f!)(tc, DV, alg) function $eig_pb(ΔDV) Δt = zerovector(t) - MatrixAlgebraKit.$eig_f_pb!(Δt, DV, unthunk.(ΔDV)) + MatrixAlgebraKit.$eig_f_pb!(Δt, t, DV, unthunk.(ΔDV)) return NoTangent(), Δt, ZeroTangent(), NoTangent() end function $eig_pb(::Tuple{ZeroTangent,ZeroTangent}) @@ -118,11 +118,11 @@ for svd_f in (:svd_compact, :svd_full) svd_f! = Symbol(svd_f, "!") @eval begin function ChainRulesCore.rrule(::typeof($svd_f!), t::AbstractTensorMap, USVᴴ, alg) - tc = copy_input($svd_f, t) + tc = MatrixAlgebraKit.copy_input($svd_f, t) USVᴴ = $(svd_f!)(tc, USVᴴ, alg) function svd_pullback(ΔUSVᴴ) Δt = zerovector(t) - MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ)) + MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ)) return NoTangent(), Δt, ZeroTangent(), NoTangent() end function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) @@ -137,23 +137,28 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ alg::TruncatedAlgorithm) tc = MatrixAlgebraKit.copy_input(svd_compact, t) USVᴴ = svd_compact!(tc, USVᴴ, alg.alg) + USVᴴ_trunc, ind = TensorKit.Factorizations.truncate(svd_trunc!, USVᴴ, alg.trunc) + svd_trunc_pullback = _make_svd_trunc_pullback(t, USVᴴ, ind) + return USVᴴ_trunc, svd_trunc_pullback +end +function _make_svd_trunc_pullback(t::AbstractTensorMap, USVᴴ, ind) function svd_trunc_pullback(ΔUSVᴴ) Δt = zerovector(t) - MatrixAlgebraKit.svd_compact_pullback!(Δt, USVᴴ, unthunk.(ΔUSVᴴ)) - return NoTangent(), ΔA, ZeroTangent(), NoTangent() + MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ), ind) + return NoTangent(), Δt, ZeroTangent(), NoTangent() end - function svd_trunc_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) + function svd_trunc_pullback(::NTuple{3,ZeroTangent}) return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end - return MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, alg.trunc), svd_trunc_pullback + return svd_trunc_pullback end function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, alg) - tc = copy_input(left_polar, t) + tc = MatrixAlgebraKit.copy_input(left_polar, t) WP = left_polar!(tc, WP, alg) function left_polar_pullback(ΔWP) Δt = zerovector(t) - MatrixAlgebraKit.left_polar_pullback!(Δt, WP, unthunk.(ΔWP)) + MatrixAlgebraKit.left_polar_pullback!(Δt, t, WP, unthunk.(ΔWP)) return NoTangent(), Δt, ZeroTangent(), NoTangent() end function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) @@ -163,11 +168,11 @@ function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, a end function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, alg) - tc = copy_input(left_polar, t) - PWᴴ = right_polar!(Ac, PWᴴ, alg) + tc = MatrixAlgebraKit.copy_input(left_polar, t) + PWᴴ = right_polar!(tc, PWᴴ, alg) function right_polar_pullback(ΔPWᴴ) Δt = zerovector(t) - MatrixAlgebraKit.right_polar_pullback!(Δt, PWᴴ, unthunk.(ΔPWᴴ)) + MatrixAlgebraKit.right_polar_pullback!(Δt, t, PWᴴ, unthunk.(ΔPWᴴ)) return NoTangent(), Δt, ZeroTangent(), NoTangent() end function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) @@ -176,41 +181,6 @@ function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PW return PWᴴ, right_polar_pullback end -# for f in (:tsvd, :eig, :eigh) -# f! = Symbol(f, :!) -# f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!) -# f_pullback = Symbol(f, :_pullback) -# f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!) -# @eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap; -# trunc::TruncationStrategy=TensorKit.notrunc(), -# kwargs...) -# # TODO: I think we can use f! here without issues because we don't actually require -# # the data of `t` anymore. -# F = $f(t; trunc=TensorKit.notrunc(), kwargs...) - -# if trunc != TensorKit.notrunc() && !isempty(blocksectors(t)) -# F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc) -# else -# F′ = F -# end - -# function $f_pullback(ΔF′) -# ΔF = unthunk.(ΔF′) -# Δt = zerovector(t) -# foreachblock(Δt) do c, (b,) -# Fc = block.(F, Ref(c)) -# ΔFc = block.(ΔF, Ref(c)) -# $f_pullback!(b, Fc, ΔFc) -# return nothing -# end -# return NoTangent(), Δt -# end -# $f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent() - -# return F′, $f_pullback -# end -# end - # function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap) # U, S, V⁺ = tsvd(t) # s = diag(S) @@ -239,42 +209,3 @@ end # return d, eigvals_pullback # end - -# function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos()) -# alg isa MatrixAlgebraKit.LAPACK_HouseholderQR || -# error("only `alg=QR()` and `alg=QRpos()` are supported") -# QR = leftorth(t; alg) -# function leftorth!_pullback(ΔQR′) -# ΔQR = unthunk.(ΔQR′) -# Δt = zerovector(t) -# foreachblock(Δt) do c, (b,) -# QRc = block.(QR, Ref(c)) -# ΔQRc = block.(ΔQR, Ref(c)) -# qr_compact_pullback!(b, QRc, ΔQRc) -# return nothing -# end -# return NoTangent(), Δt -# end -# leftorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() - -# return QR, leftorth!_pullback -# end - -# function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos()) -# alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ || -# error("only `alg=LQ()` and `alg=LQpos()` are supported") -# LQ = rightorth(t; alg) -# function rightorth!_pullback(ΔLQ′) -# ΔLQ = unthunk(ΔLQ′) -# Δt = zerovector(t) -# foreachblock(Δt) do c, (b,) -# LQc = block.(LQ, Ref(c)) -# ΔLQc = block.(ΔLQ, Ref(c)) -# lq_compact_pullback!(b, LQc, ΔLQc) -# return nothing -# end -# return NoTangent(), Δt -# end -# rightorth!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent() -# return LQ, rightorth!_pullback -# end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 14c5b80b7..b6a2355d1 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -42,9 +42,11 @@ import MatrixAlgebraKit: default_algorithm, left_orth!, right_orth!, left_null!, right_null!, truncate!, findtruncated, findtruncated_svd, diagview, isisometry -import MatrixAlgebraKit: qr_compact_pullback!, lq_compact_pullback!, svd_compact_pullback!, - left_polar_pullback!, right_polar_pullback!, eig_full_pullback!, - eigh_full_pullback! +using MatrixAlgebraKit: qr_compact_pullback!, lq_compact_pullback!, + svd_pullback!, svd_trunc_pullback!, + eig_pullback!, eig_trunc_pullback!, + eigh_pullback!, eigh_trunc_pullback!, + left_polar_pullback!, right_polar_pullback! include("utility.jl") include("matrixalgebrakit.jl") diff --git a/src/tensors/factorizations/pullbacks.jl b/src/tensors/factorizations/pullbacks.jl index 6d6028384..daa54f46d 100644 --- a/src/tensors/factorizations/pullbacks.jl +++ b/src/tensors/factorizations/pullbacks.jl @@ -1,12 +1,41 @@ for pullback! in (:qr_compact_pullback!, :lq_compact_pullback!, - :svd_compact_pullback!, - :left_polar_pullback!, :right_polar_pullback!, - :eig_full_pullback!, :eigh_full_pullback!) - @eval function $pullback!(Δt::AbstractTensorMap, F, ΔF; kwargs...) - foreachblock(Δt) do c, (b,) + :left_polar_pullback!, :right_polar_pullback!) + @eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, + F, ΔF; kwargs...) + foreachblock(Δt, t) do c, (Δb, b) Fc = block.(F, Ref(c)) ΔFc = block.(ΔF, Ref(c)) - return $pullback!(b, Fc, ΔFc; kwargs...) + return $pullback!(Δb, b, Fc, ΔFc; kwargs...) + end + return Δt + end +end + +_notrunc_ind(t) = SectorDict(c => Colon() for c in blocksectors(t)) + +for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!) + @eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, + F, ΔF, inds=_notrunc_ind(t); + kwargs...) + for (c, ind) in inds + Δb = block(Δt, c) + b = block(t, c) + Fc = block.(F, Ref(c)) + ΔFc = block.(ΔF, Ref(c)) + $pullback!(Δb, b, Fc, ΔFc, ind; kwargs...) + end + return Δt + end +end + +for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_pullback!) + @eval function MatrixAlgebraKit.$pullback_trunc!(Δt::AbstractTensorMap, + t::AbstractTensorMap, + F, ΔF; kwargs...) + foreachblock(Δt, t) do c, (Δb, b) + Fc = block.(F, Ref(c)) + ΔFc = block.(ΔF, Ref(c)) + return $pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...) end return Δt end diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 5d5178470..667a5f23d 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -31,73 +31,94 @@ end # truncate! # --------- -function truncate!(::typeof(svd_trunc!), - (U, S, Vᴴ)::Tuple{AbstractTensorMap,AbstractTensorMap,AbstractTensorMap}, - strategy::TruncationStrategy) - ind = findtruncated_svd(diagview(S), strategy) - V_truncated = spacetype(S)(c => length(I) for (c, I) in ind) +_blocklength(d::Integer, ind) = _blocklength(Base.OneTo(d), ind) +_blocklength(ax, ind) = length(ax[ind]) +function truncate_space(V::ElementarySpace, inds) + return spacetype(V)(c => _blocklength(dim(V, c), ind) for (c, ind) in inds) +end - Ũ = similar(U, codomain(U) ← V_truncated) - for (c, b) in blocks(Ũ) - I = get(ind, c, nothing) +function truncate_domain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, inds) + for (c, b) in blocks(tdst) + I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, @view(block(U, c)[:, I])) + copy!(b, @view(block(tsrc, c)[:, I])) end - - S̃ = DiagonalTensorMap{scalartype(S)}(undef, V_truncated) - for (c, b) in blocks(S̃) - I = get(ind, c, nothing) + return tdst +end +function truncate_codomain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, inds) + for (c, b) in blocks(tdst) + I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b.diag, @view(block(S, c).diag[I])) + copy!(b, @view(block(tsrc, c)[I, :])) end - - Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) - for (c, b) in blocks(Ṽᴴ) - I = get(ind, c, nothing) + return tdst +end +function truncate_diagonal!(Ddst::DiagonalTensorMap, Dsrc::DiagonalTensorMap, inds) + for (c, b) in blocks(Ddst) + I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, @view(block(Vᴴ, c)[I, :])) + copy!(diagview(b), @view(diagview(block(Dsrc, c))[I])) end + return Ddst +end - return Ũ, S̃, Ṽᴴ +function truncate(::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3,AbstractTensorMap}, + strategy::TruncationStrategy) + ind = findtruncated_svd(diagview(S), strategy) + V_truncated = truncate_space(space(S, 1), ind) + + Ũ = similar(U, codomain(U) ← V_truncated) + truncate_domain!(Ũ, U, ind) + S̃ = DiagonalTensorMap{scalartype(S)}(undef, V_truncated) + truncate_diagonal!(S̃, S, ind) + Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) + truncate_codomain!(Ṽᴴ, Vᴴ, ind) + + return (Ũ, S̃, Ṽᴴ), ind +end +function truncate!(::typeof(svd_trunc!), USVᴴ::NTuple{3,AbstractTensorMap}, + strategy::TruncationStrategy) + USVᴴ_trunc, _ = truncate(svd_trunc!, USVᴴ, strategy) + return USVᴴ_trunc end -function truncate!(::typeof(left_null!), - (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, - strategy::MatrixAlgebraKit.TruncationStrategy) +function truncate(::typeof(left_null!), + (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, + strategy::MatrixAlgebraKit.TruncationStrategy) extended_S = SectorDict(c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) for (c, b) in blocks(S)) ind = findtruncated(extended_S, strategy) - V_truncated = spacetype(S)(c => length(axes(b, 1)[ind[c]]) for (c, b) in blocks(S)) + V_truncated = truncate_space(space(S, 1), ind) Ũ = similar(U, codomain(U) ← V_truncated) - for (c, b) in blocks(Ũ) - copy!(b, @view(block(U, c)[:, ind[c]])) - end - return Ũ + truncate_domain!(Ũ, U, ind) + return Ũ, ind +end +function truncate!(::typeof(left_null!), US::NTuple{2,AbstractTensorMap}, + strategy::TruncationStrategy) + U_trunc, _ = truncate(left_null!, US, strategy) + return U_trunc end for f! in (:eig_trunc!, :eigh_trunc!) - @eval function truncate!(::typeof($f!), - (D, V)::Tuple{AbstractTensorMap,AbstractTensorMap}, - strategy::TruncationStrategy) + @eval function truncate(::typeof($f!), + (D, V)::Tuple{DiagonalTensorMap,AbstractTensorMap}, + strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) - for (c, b) in blocks(D̃) - I = get(ind, c, nothing) - @assert !isnothing(I) - copy!(b.diag, @view(block(D, c).diag[I])) - end + truncate_diagonal!(D̃, D, ind) Ṽ = similar(V, codomain(V) ← V_truncated) - for (c, b) in blocks(Ṽ) - I = get(ind, c, nothing) - @assert !isnothing(I) - copy!(b, @view(block(V, c)[:, I])) - end + truncate_domain!(Ṽ, V, ind) - return D̃, Ṽ + return (D̃, Ṽ), ind + end + @eval function truncate!(::typeof($f!), DV::Tuple{DiagonalTensorMap,AbstractTensorMap}, + strategy::TruncationStrategy) + DV_trunc, _ = truncate($f!, DV, strategy) + return DV_trunc end end @@ -141,7 +162,7 @@ function findtruncated_svd(values::SectorDict, strategy::TruncationStrategy) end function findtruncated(values::SectorDict, ::NoTruncation) - return SectorDict(c => Base.OneTo(length(b)) for (c, b) in values) + return SectorDict(c => Colon() for (c, b) in values) end function findtruncated(values::SectorDict, strategy::TruncationByOrder) diff --git a/test/ad.jl b/test/ad.jl index 52f7dee90..883b24578 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -5,7 +5,7 @@ using Random using LinearAlgebra using Zygote using MatrixAlgebraKit -using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ +using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ, diagview const _repartition = @static if isdefined(Base, :get_extension) Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition @@ -68,35 +68,6 @@ function test_ad_rrule(f, args...; check_inferred=false, kwargs...) return nothing end -# rrules for functions that destroy inputs -# ---------------------------------------- -for f in - (:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, - :eig_full, :eigh_full, :svd_compact, :svd_trunc, :left_polar, :right_polar) - copy_f = Symbol(:copy_, f) - f! = Symbol(f, '!') - @eval begin - function $copy_f(input) - if $f === eigh_full - input = (input + input') / 2 - end - return $f(input) - end - function ChainRulesCore.rrule(::typeof($copy_f), input) - output = MatrixAlgebraKit.initialize_output($f!, input) - if $f === eigh_full - input = (input + input') / 2 - else - input = copy!(similar(input), input) - end - - output, pb = ChainRulesCore.rrule($f!, input, output) - return output, x -> (NoTangent(), pb(x)[2], NoTangent()) - end - end -end - - # Gauge fixing tangents # --------------------- function remove_qrgauge_dependence!(ΔQ, t, Q) @@ -126,9 +97,13 @@ function remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(D)) gaugepart = V' * ΔV for (c, b) in blocks(gaugepart) - Dc = block(D, c) - mask = abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol - b[mask] .= 0 + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end end mul!(ΔV, V / (V' * V), gaugepart, -1, 1) return ΔV @@ -138,27 +113,36 @@ function remove_eighgauge_dependence!(ΔV, D, V; gaugepart = V' * ΔV gaugepart = (gaugepart - gaugepart') / 2 for (c, b) in blocks(gaugepart) - Dc = block(D, c) - mask = abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol - b[mask] .= 0 + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end end mul!(ΔV, V / (V' * V), gaugepart, -1, 1) return ΔV end - -function remove_svdgauge_dependence!(ΔU, ΔV, U, S, V; +function remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) - gaugepart = U' * ΔU + V * ΔV' + gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' gaugepart = (gaugepart - gaugepart') / 2 for (c, b) in blocks(gaugepart) - Sc = block(S, c) - mask = abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol - b[mask] .= 0 + Sd = diagview(block(S, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0) + end end mul!(ΔU, U, gaugepart, -1, 1) return ΔU, ΔVᴴ end +project_hermitian(A) = (A + A') / 2 + # Tests # ----- @@ -192,7 +176,7 @@ for V in spacelist @timedtestset "AD with symmetry $Istr" verbose = true begin V1, V2, V3, V4, V5 = V W = V1 ⊗ V2 - false && @timedtestset "Basic utility" begin + @timedtestset "Basic utility" begin T1 = randn(Float64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) T2 = randn(ComplexF64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) @@ -218,7 +202,7 @@ for V in spacelist test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space) end - false && @timedtestset "Basic utility (DiagonalTensor)" begin + @timedtestset "Basic utility (DiagonalTensor)" begin for v in V rdim = reduceddim(v) D1 = DiagonalTensorMap(randn(rdim), v) @@ -256,7 +240,7 @@ for V in spacelist end end - false && @timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes + @timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = randn(T, space(A)) @@ -286,7 +270,7 @@ for V in spacelist symmetricbraiding && test_rrule(⊗, D, E) end - false && @timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes + @timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes for i in 1:3 E = randn(T, ⊗(V[1:i]...) ← ⊗(V[1:i]...)) test_rrule(LinearAlgebra.tr, E) @@ -317,7 +301,7 @@ for V in spacelist end end - false && symmetricbraiding && + symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes atol = precision(T) rtol = precision(T) @@ -528,17 +512,21 @@ for V in spacelist Δd2 = randn!(similar(d, space(d))) remove_eighgauge_dependence!(Δv, d, v) - test_ad_rrule(eigh_full, t; output_tangent=(Δd, Δv), atol, rtol) - test_ad_rrule(first ∘ eigh_full, t; output_tangent=Δd, atol, rtol) - test_ad_rrule(last ∘ eigh_full, t; output_tangent=Δv, atol, rtol) - test_ad_rrule(eigh_full, t; output_tangent=(Δd2, Δv), atol, rtol) + # necessary for FiniteDifferences to not complain + eigh_full′ = eigh_full ∘ project_hermitian + + test_ad_rrule(eigh_full′, t; output_tangent=(Δd, Δv), atol, rtol) + test_ad_rrule(first ∘ eigh_full′, t; output_tangent=Δd, atol, rtol) + test_ad_rrule(last ∘ eigh_full′, t; output_tangent=Δv, atol, rtol) + test_ad_rrule(eigh_full′, t; output_tangent=(Δd2, Δv), atol, rtol) end end @testset "Singular value decomposition" begin for T in eltypes, - t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', - DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) + t in (randn(T, V1, V1), randn(T, W, W), randn(T, W, W)) + # TODO: fix diagonaltensormap case + # DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) atol = rtol = degeneracy_atol = precision(T) * dim(space(t)) USVᴴ = svd_compact(t) @@ -551,76 +539,48 @@ for V in spacelist test_ad_rrule(svd_compact, t; output_tangent=(ΔU, ΔS, ΔVᴴ), atol, rtol) test_ad_rrule(svd_compact, t; output_tangent=(ΔU, ΔS2, ΔVᴴ), atol, rtol) - trunc = truncrank(min(dim(domain(t)), dim(codomain(t))) ÷ 2) - USVᴴ′ = svd_trunc(t; trunc) - ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ′) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol) - + # TODO: I'm not sure how to properly test with spaces that might change + # with the finite-difference methods, as then the jacobian is ill-defined. + + trunc = truncrank(round(Int, min(dim(domain(t)), dim(codomain(t))) ÷ 2)) + USVᴴ_trunc = svd_trunc(t; trunc) + ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) + remove_svdgauge_dependence!(ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], + USVᴴ_trunc...; + degeneracy_atol) + # test_ad_rrule(svd_trunc, t; + # fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol) + + trunc = truncspace(space(USVᴴ_trunc[2], 1)) + USVᴴ_trunc = svd_trunc(t; trunc) + ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) + remove_svdgauge_dependence!(ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], + USVᴴ_trunc...; + degeneracy_atol) + test_ad_rrule(svd_trunc, t; + fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol) + + # ϵ = norm(*(USVᴴ_trunc...) - t) + # trunc = truncerror(; atol=ϵ) + # USVᴴ_trunc = svd_trunc(t; trunc) + # ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) + # remove_svdgauge_dependence!(ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], USVᴴ_trunc...; + # degeneracy_atol) + # test_ad_rrule(svd_trunc, t; + # fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol) + + tol = minimum(((c, b),) -> minimum(diagview(b)), blocks(USVᴴ_trunc[2])) + trunc = trunctol(; atol=10 * tol) + USVᴴ_trunc = svd_trunc(t; trunc) + ΔUSVᴴ_trunc = rand_tangent.(USVᴴ_trunc) + remove_svdgauge_dependence!(ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], + USVᴴ_trunc...; + degeneracy_atol) test_ad_rrule(svd_trunc, t; - fkwargs=(; trunc), output_tangent=(ΔU, ΔS, ΔVᴴ), atol, - rtol) + fkwargs=(; trunc), output_tangent=ΔUSVᴴ_trunc, atol, rtol) end end - # let (U, S, V) = tsvd(A) - # ΔU = randn(scalartype(U), space(U)) - # ΔS = randn(scalartype(S), space(S)) - # ΔV = randn(scalartype(V), space(V)) - # if T <: Complex # remove gauge dependent components - # gaugepart = U' * ΔU + V * ΔV' - # for (c, b) in blocks(gaugepart) - # mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1) - # end - # end - # test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV)) - - # allS = mapreduce(x -> diag(x[2]), vcat, blocks(S)) - # truncval = (maximum(allS) + minimum(allS)) / 2 - # U, S, V = tsvd(A; trunc=truncerror(; atol=truncval)) - # ΔU = randn(scalartype(U), space(U)) - # ΔS = randn(scalartype(S), space(S)) - # ΔV = randn(scalartype(V), space(V)) - # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - # test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV), - # fkwargs=(; trunc=truncerror(; atol=truncval))) - # end - - # let (U, S, V) = tsvd(B) - # ΔU = randn(scalartype(U), space(U)) - # ΔS = randn(scalartype(S), space(S)) - # ΔV = randn(scalartype(V), space(V)) - # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - # test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV)) - - # Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2) - # for (c, b) in blocks(S))) - - # U, S, V = tsvd(B; trunc=truncspace(Vtrunc)) - # ΔU = randn(scalartype(U), space(U)) - # ΔS = randn(scalartype(S), space(S)) - # ΔV = randn(scalartype(V), space(V)) - # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - # test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV), - # fkwargs=(; trunc=truncspace(Vtrunc))) - # end - - # let (U, S, V) = tsvd(C) - # ΔU = randn(scalartype(U), space(U)) - # ΔS = randn(scalartype(S), space(S)) - # ΔV = randn(scalartype(V), space(V)) - # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - # test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV)) - - # c, = argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])), blocks(S)) - # trunc = truncrank(round(Int, 2 * dim(c))) - # U, S, V = tsvd(C; trunc) - # ΔU = randn(scalartype(U), space(U)) - # ΔS = randn(scalartype(S), space(S)) - # ΔV = randn(scalartype(V), space(V)) - # T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V) - # test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV), fkwargs=(; trunc)) - # end - # let D = LinearAlgebra.eigvals(C) # ΔD = diag(randn(complex(scalartype(C)), space(C))) # test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD, diff --git a/test/runtests.jl b/test/runtests.jl index 03d7c183f..e2a1c8c75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,11 +110,11 @@ VSU₂U₁ = (Vect[SU2Irrep ⊠ U1Irrep]((0, 0) => 1, (1 // 2, -1) => 1), Vect[SU2Irrep ⊠ U1Irrep]((1 // 2, 1) => 1, (1, -2) => 1)', Vect[SU2Irrep ⊠ U1Irrep]((0, 0) => 2, (0, 2) => 1, (1 // 2, 1) => 1), Vect[SU2Irrep ⊠ U1Irrep]((0, 0) => 1, (1 // 2, 1) => 1)') -# VSU₃ = (ℂ[SU3Irrep]((0, 0, 0) => 3, (1, 0, 0) => 1), -# ℂ[SU3Irrep]((0, 0, 0) => 3, (2, 0, 0) => 1)', -# ℂ[SU3Irrep]((1, 1, 0) => 1, (2, 1, 0) => 1), -# ℂ[SU3Irrep]((1, 0, 0) => 1, (2, 0, 0) => 1), -# ℂ[SU3Irrep]((0, 0, 0) => 1, (1, 0, 0) => 1, (1, 1, 0) => 1)') +Vfib = (Vect[FibonacciAnyon](:I => 1, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 3, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2)) if !is_buildkite Ti = time() From 169f32490e43096191b3bed452ad0b0520c87cd0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 29 Sep 2025 17:08:23 -0400 Subject: [PATCH 109/126] some final small fixes --- src/tensors/factorizations/adjoint.jl | 6 ++ .../factorizations/matrixalgebrakit.jl | 8 -- test/ad.jl | 73 +++++++++++-------- test/bugfixes.jl | 4 +- 4 files changed, 50 insertions(+), 41 deletions(-) diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index b9f8acd52..6a4dafc67 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -78,6 +78,12 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!) $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end + + # disambiguate by prohibition + @eval function initialize_output(::typeof($f!), t::AdjointTensorMap, + alg::DiagonalAlgorithm) + throw(MethodError($f!, (t, alg))) + end end # avoid amgiguity function initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap, diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 560a210d5..02b9347ad 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -128,14 +128,6 @@ function check_input(::typeof(svd_compact!), t::AbstractTensorMap, USVᴴ, return nothing end -function check_input(::typeof(svd_vals!), t::AbstractTensorMap, S::SectorDict, - ::AbstractAlgorithm) - @check_scalar S t real - V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) - @check_space(S, V_cod ← V_dom) - return nothing -end - function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t real @assert D isa DiagonalTensorMap diff --git a/test/ad.jl b/test/ad.jl index 883b24578..dd820b21f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -148,22 +148,32 @@ project_hermitian(A) = (A + A') / 2 ChainRulesTestUtils.test_method_tables() -spacelist = try - if ENV["CI"] == "true" - println("Detected running on CI") - if Sys.iswindows() - (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂) - elseif Sys.isapple() - (Vtr, Vℤ₃, VfU₁, VfSU₂) - else - (Vtr, VU₁, VCU₁, VfSU₂, Vfib) - end - else - (Vtr, Vℤ₃, VU₁, VfU₁, VSU₂, VfSU₂, Vfib) - end -catch - (Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, Vfib) -end +spacelist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + (Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 3, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2)), + (Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2)), + (Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 3, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)'), + (Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'), + (Vect[FibonacciAnyon](:I => 1, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 3, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2))) for V in spacelist I = sectortype(eltype(V)) @@ -308,8 +318,8 @@ for V in spacelist @timedtestset "tensortrace!" begin for _ in 1:5 - k1 = rand(0:3) - k2 = k1 == 3 ? 1 : rand(1:2) + k1 = rand(0:2) + k2 = rand(1:2) V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) @@ -331,13 +341,13 @@ for V in spacelist end @timedtestset "tensoradd!" begin - A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) α = randn(T) β = randn(T) # repeat a couple times to get some distribution of arrows for _ in 1:5 - p = randindextuple(length(V)) + p = randindextuple(numind(A)) C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) @@ -399,19 +409,20 @@ for V in spacelist end @timedtestset "Factorizations" begin + W = V[1] ⊗ V[2] @testset "QR" begin for T in eltypes, t in (randn(T, W, W), randn(T, W, W)', - randn(T, W, V1), randn(T, V1, W), - randn(T, W, V1)', randn(T, V1, W)', - DiagonalTensorMap(randn(T, reduceddim(V1)), V1)) + randn(T, W, V[1]), randn(T, V[1], W), + randn(T, W, V[1])', randn(T, V[1], W)', + DiagonalTensorMap(randn(T, reduceddim(V[1])), V[1])) atol = rtol = precision(T) * dim(space(t)) fkwargs = (; positive=true) # make FiniteDifferences happy test_ad_rrule(qr_compact, t; fkwargs, atol, rtol) - test_ad_rrule(first ∘ qr_compact, t; fkwargs, atol, rtol,) - test_ad_rrule(last ∘ qr_compact, t; fkwargs, atol, rtol,) + test_ad_rrule(first ∘ qr_compact, t; fkwargs, atol, rtol) + test_ad_rrule(last ∘ qr_compact, t; fkwargs, atol, rtol) # qr_full/qr_null requires being careful with gauges Q, R = qr_full(t) @@ -445,9 +456,9 @@ for V in spacelist @testset "LQ" begin for T in eltypes, t in (randn(T, W, W), randn(T, W, W)', - randn(T, W, V1), randn(T, V1, W), - randn(T, W, V1)', randn(T, V1, W)', - DiagonalTensorMap(randn(T, reduceddim(V1)), V1)) + randn(T, W, V[1]), randn(T, V[1], W), + randn(T, W, V[1])', randn(T, V[1], W)', + DiagonalTensorMap(randn(T, reduceddim(V[1])), V[1])) atol = rtol = precision(T) * dim(space(t)) fkwargs = (; positive=true) # make FiniteDifferences happy @@ -489,8 +500,8 @@ for V in spacelist @testset "Eigenvalue decomposition" begin for T in eltypes, - t in (rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', - DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) + t in (rand(T, V[1], V[1]), rand(T, W, W), rand(T, W, W)', + DiagonalTensorMap(rand(T, reduceddim(V[1])), V[1])) atol = rtol = precision(T) * dim(space(t)) @@ -524,7 +535,7 @@ for V in spacelist @testset "Singular value decomposition" begin for T in eltypes, - t in (randn(T, V1, V1), randn(T, W, W), randn(T, W, W)) + t in (randn(T, V[1], V[1]), randn(T, W, W), randn(T, W, W)) # TODO: fix diagonaltensormap case # DiagonalTensorMap(rand(T, reduceddim(V1)), V1)) diff --git a/test/bugfixes.jl b/test/bugfixes.jl index fb2f0f45c..3b93f998a 100644 --- a/test/bugfixes.jl +++ b/test/bugfixes.jl @@ -47,7 +47,7 @@ # https://github.com/quantumkithub/TensorKit.jl/issues/201 @testset "Issue #201" begin function f(A::AbstractTensorMap) - U, S, V, = tsvd(A) + U, S, V, = svd_compact(A) return tr(S) end function f(A::AbstractMatrix) @@ -60,7 +60,7 @@ @test convert(Array, grad1) ≈ grad2 function g(A::AbstractTensorMap) - U, S, V, = tsvd(A) + U, S, V, = svd_compact(A) return tr(U * V) end function g(A::AbstractMatrix) From b66cacfb284b879ff916084dbee655c994503fc0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 30 Sep 2025 14:23:26 -0400 Subject: [PATCH 110/126] enable AD tests on CI --- test/runtests.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e2a1c8c75..ade9f7cb5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -124,8 +124,7 @@ if !is_buildkite include("factorizations.jl") include("diagonal.jl") include("planar.jl") - # TODO: remove once we know AD is slow on macOS CI - if !(Sys.isapple() && get(ENV, "CI", "false") == "true") && isempty(VERSION.prerelease) + if isempty(VERSION.prerelease) include("ad.jl") end include("bugfixes.jl") From bb9c05a25c11b711fbc408b28006b57537290b88 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 30 Sep 2025 19:28:41 -0400 Subject: [PATCH 111/126] move permutedcopy_oftype --- src/auxiliary/deprecate.jl | 4 ++++ src/tensors/factorizations/factorizations.jl | 2 +- src/tensors/factorizations/utility.jl | 3 --- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index 6417cb1ec..1d58fb119 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -69,6 +69,10 @@ _kindof(::DiagonalAlgorithm) = :svd # shouldn't really matter _drop_alg(; alg=nothing, kwargs...) = kwargs _drop_p(; p=nothing, kwargs...) = kwargs +function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple) + return permute!(similar(t, T, permute(space(t), p)), t, p) +end + # orthogonalization export leftorth, leftorth!, rightorth, rightorth! function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index b6a2355d1..70e4dbeba 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -11,7 +11,7 @@ export qr_full, qr_compact, qr_null export qr_full!, qr_compact!, qr_null! export lq_full, lq_compact, lq_null export lq_full!, lq_compact!, lq_null! -export copy_oftype, permutedcopy_oftype, factorisation_scalartype, one! +export copy_oftype, factorisation_scalartype, one! export TruncationScheme, notrunc, trunctol, truncerror, truncrank, truncspace, truncfilter, PolarViaSVD diff --git a/src/tensors/factorizations/utility.jl b/src/tensors/factorizations/utility.jl index 874b944d3..a6721ee31 100644 --- a/src/tensors/factorizations/utility.jl +++ b/src/tensors/factorizations/utility.jl @@ -12,9 +12,6 @@ function factorisation_scalartype(t::AbstractTensorMap) end factorisation_scalartype(f, t) = factorisation_scalartype(t) -function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple) - return permute!(similar(t, T, permute(space(t), p)), t, p) -end function copy_oftype(t::AbstractTensorMap, T::Type{<:Number}) return copy!(similar(t, T, space(t)), t) end From 5aab60223db62af074e906076f65ce93c004fb86 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 30 Sep 2025 19:30:33 -0400 Subject: [PATCH 112/126] change deprecation warning message --- src/auxiliary/deprecate.jl | 40 +++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/auxiliary/deprecate.jl b/src/auxiliary/deprecate.jl index 1d58fb119..5853cad01 100644 --- a/src/auxiliary/deprecate.jl +++ b/src/auxiliary/deprecate.jl @@ -76,23 +76,23 @@ end # orthogonalization export leftorth, leftorth!, rightorth, rightorth! function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) + Base.depwarn("`leftorth` is deprecated, use `left_orth` instead", :leftorth) return leftorth!(permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p); kwargs...) end function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) + Base.depwarn("`rightorth` is deprecated, use `right_orth` instead", :rightorth) return rightorth!(permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p); kwargs...) end function leftorth(t::AbstractTensorMap; kwargs...) - Base.depwarn("`leftorth` is no longer supported, use `left_orth` instead", :leftorth) + Base.depwarn("`leftorth` is deprecated, use `left_orth` instead", :leftorth) return leftorth!(copy_oftype(t, factorisation_scalartype(leftorth, t)); kwargs...) end function rightorth(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightorth` is no longer supported, use `right_orth` instead", :rightorth) + Base.depwarn("`rightorth` is deprecated, use `right_orth` instead", :rightorth) return rightorth!(copy_oftype(t, factorisation_scalartype(rightorth, t)); kwargs...) end function leftorth!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`leftorth!` is no longer supported, use `left_orth!` instead", :leftorth!) + Base.depwarn("`leftorth!` is deprecated, use `left_orth!` instead", :leftorth!) haskey(kwargs, :alg) || return left_orth!(t; kwargs...) alg = kwargs[:alg] kind = _kindof(alg) @@ -102,7 +102,7 @@ function leftorth!(t::AbstractTensorMap; kwargs...) throw(ArgumentError("invalid leftorth kind")) end function rightorth!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightorth!` is no longer supported, use `right_orth!` instead", :rightorth!) + Base.depwarn("`rightorth!` is deprecated, use `right_orth!` instead", :rightorth!) haskey(kwargs, :alg) || return right_orth!(t; kwargs...) alg = kwargs[:alg] kind = _kindof(alg) @@ -115,23 +115,23 @@ end # nullspaces export leftnull, leftnull!, rightnull, rightnull! function leftnull(t::AbstractTensorMap; kwargs...) - Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) + Base.depwarn("`leftnull` is deprecated, use `left_null` instead", :leftnull) return leftnull!(copy_oftype(t, factorisation_scalartype(leftnull, t)); kwargs...) end function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`leftnull` is no longer supported, use `left_null` instead", :leftnull) + Base.depwarn("`leftnull` is deprecated, use `left_null` instead", :leftnull) return leftnull!(permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p); kwargs...) end function rightnull(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) + Base.depwarn("`rightnull` is deprecated, use `right_null` instead", :rightnull) return rightnull!(copy_oftype(t, factorisation_scalartype(rightnull, t)); kwargs...) end function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...) - Base.depwarn("`rightnull` is no longer supported, use `right_null` instead", :rightnull) + Base.depwarn("`rightnull` is deprecated, use `right_null` instead", :rightnull) return rightnull!(permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p); kwargs...) end function leftnull!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`left_null!` is no longer supported, use `left_null!` instead", :leftnull!) + Base.depwarn("`left_null!` is deprecated, use `left_null!` instead", :leftnull!) haskey(kwargs, :alg) || return left_null!(t; kwargs...) alg = kwargs[:alg] kind = _kindof(alg) @@ -140,7 +140,7 @@ function leftnull!(t::AbstractTensorMap; kwargs...) throw(ArgumentError("invalid leftnull kind")) end function rightnull!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`rightnull!` is no longer supported, use `right_null!` instead", :rightnull!) + Base.depwarn("`rightnull!` is deprecated, use `right_null!` instead", :rightnull!) haskey(kwargs, :alg) || return right_null!(t; kwargs...) alg = kwargs[:alg] kind = _kindof(alg) @@ -159,19 +159,19 @@ export eig!, eigh!, eigen, eigen! eigen!(permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p); kwargs...), false) function eig(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eig` is no longer supported, use `eig_full` or `eig_trunc` instead", :eig) + Base.depwarn("`eig` is deprecated, use `eig_full` or `eig_trunc` instead", :eig) return haskey(kwargs, :trunc) ? eig_trunc(t; kwargs...) : eig_full(t; kwargs...) end function eig!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eig!` is no longer supported, use `eig_full!` or `eig_trunc!` instead", :eig!) + Base.depwarn("`eig!` is deprecated, use `eig_full!` or `eig_trunc!` instead", :eig!) return haskey(kwargs, :trunc) ? eig_trunc!(t; kwargs...) : eig_full!(t; kwargs...) end function eigh(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eigh` is no longer supported, use `eigh_full` or `eigh_trunc` instead", :eigh) + Base.depwarn("`eigh` is deprecated, use `eigh_full` or `eigh_trunc` instead", :eigh) return haskey(kwargs, :trunc) ? eigh_trunc(t; kwargs...) : eigh_full(t; kwargs...) end function eigh!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`eigh!` is no longer supported, use `eigh_full!` or `eigh_trunc!` instead", :eigh!) + Base.depwarn("`eigh!` is deprecated, use `eigh_full!` or `eigh_trunc!` instead", :eigh!) return haskey(kwargs, :trunc) ? eigh_trunc!(t; kwargs...) : eigh_full!(t; kwargs...) end @@ -180,17 +180,17 @@ export tsvd, tsvd! @deprecate(tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...), tsvd!(permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p); kwargs...)) function tsvd(t::AbstractTensorMap; kwargs...) - Base.depwarn("`tsvd` is no longer supported, use `svd_compact`, `svd_full` or `svd_trunc` instead", :tsvd) + Base.depwarn("`tsvd` is deprecated, use `svd_compact`, `svd_full` or `svd_trunc` instead", :tsvd) if haskey(kwargs, :p) - Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", :tsvd) + Base.depwarn("p is a deprecated kwarg, and should be specified through the truncation strategy", :tsvd) kwargs = _drop_p(; kwargs...) end return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : svd_compact(t; kwargs...) end function tsvd!(t::AbstractTensorMap; kwargs...) - Base.depwarn("`tsvd!` is no longer supported, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", :tsvd!) + Base.depwarn("`tsvd!` is deprecated, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", :tsvd!) if haskey(kwargs, :p) - Base.depwarn("p is no longer a supported kwarg, and should be specified through the truncation strategy", :tsvd!) + Base.depwarn("p is a deprecated kwarg, and should be specified through the truncation strategy", :tsvd!) kwargs = _drop_p(; kwargs...) end return haskey(kwargs, :trunc) ? svd_trunc!(t; kwargs...) : svd_compact!(t; kwargs...) From 59375448cfbc7a04380cab566a6e5130f98e2d78 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 30 Sep 2025 19:32:43 -0400 Subject: [PATCH 113/126] Apply suggestions from code review Co-authored-by: Jutho --- docs/src/lib/tensors.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index 167a83adc..4d1c5c9cc 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -216,10 +216,10 @@ contract! The factorisation methods are powered by [MatrixAlgebraKit.jl](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl) and all follow the same strategy. The idea is that the `TensorMap` is interpreted as a linear -map based on the current partition of indices between `domain` and `codomain`, and then the +map based on the current partition of indices between `domain` and `codomain`, and then the entire range of MatrixAlgebraKit functions can be called. -You can specify an additional permutation of the domain and codomain indices before the -factorisation is performed by making use of [`permute`](@ref) or [`transpose`](@ref). +Factorizing a tensor according to a different partition of the indices is possible +by prepending the factorization step with an explicit call to [`permute`](@ref) or [`transpose`](@ref). For the full list of factorizations, see [Decompositions](@extref MatrixAlgebraKit). From b76d03cfe83b533c19acb307c2e0fec83296d88a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 1 Oct 2025 17:30:51 -0400 Subject: [PATCH 114/126] update MatrixAlgebraKit version --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index ff1dd832a..9e17a4488 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorKit" uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" authors = ["Jutho Haegeman"] -version = "0.15.0-DEV" +version = "0.15.0" [deps] LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" @@ -33,7 +33,7 @@ Combinatorics = "1" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.4.1" +MatrixAlgebraKit = "0.5.0" OhMyThreads = "0.8.0" PackageExtensionCompat = "1" Random = "1" From 5631fe6142d69cee2c2e0946f54f2d90c88d4c71 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 1 Oct 2025 17:31:17 -0400 Subject: [PATCH 115/126] remove `truncate!` --- src/tensors/factorizations/adjoint.jl | 2 +- src/tensors/factorizations/factorizations.jl | 2 +- src/tensors/factorizations/truncation.jl | 19 ------------------- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 6a4dafc67..395bc05cf 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -93,5 +93,5 @@ end # to fix ambiguity function svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) - return truncate!(svd_trunc!, USVᴴ′, alg.trunc) + return truncate(svd_trunc!, USVᴴ′, alg.trunc) end diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 70e4dbeba..f6f05d0bd 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -40,7 +40,7 @@ import MatrixAlgebraKit: default_algorithm, left_polar!, left_orth_polar!, right_polar!, right_orth_polar!, left_null_svd!, right_null_svd!, left_orth_svd!, right_orth_svd!, left_orth!, right_orth!, left_null!, right_null!, - truncate!, findtruncated, findtruncated_svd, + truncate, findtruncated, findtruncated_svd, diagview, isisometry using MatrixAlgebraKit: qr_compact_pullback!, lq_compact_pullback!, svd_pullback!, svd_trunc_pullback!, diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 667a5f23d..76b02fda0 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -1,9 +1,5 @@ # Strategies # ---------- - -# TODO: deprecate -const TruncationScheme = TruncationStrategy - """ TruncationSpace(V::ElementarySpace, by::Function, rev::Bool) @@ -76,11 +72,6 @@ function truncate(::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3,AbstractTensorMap return (Ũ, S̃, Ṽᴴ), ind end -function truncate!(::typeof(svd_trunc!), USVᴴ::NTuple{3,AbstractTensorMap}, - strategy::TruncationStrategy) - USVᴴ_trunc, _ = truncate(svd_trunc!, USVᴴ, strategy) - return USVᴴ_trunc -end function truncate(::typeof(left_null!), (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, @@ -94,11 +85,6 @@ function truncate(::typeof(left_null!), truncate_domain!(Ũ, U, ind) return Ũ, ind end -function truncate!(::typeof(left_null!), US::NTuple{2,AbstractTensorMap}, - strategy::TruncationStrategy) - U_trunc, _ = truncate(left_null!, US, strategy) - return U_trunc -end for f! in (:eig_trunc!, :eigh_trunc!) @eval function truncate(::typeof($f!), @@ -115,11 +101,6 @@ for f! in (:eig_trunc!, :eigh_trunc!) return (D̃, Ṽ), ind end - @eval function truncate!(::typeof($f!), DV::Tuple{DiagonalTensorMap,AbstractTensorMap}, - strategy::TruncationStrategy) - DV_trunc, _ = truncate($f!, DV, strategy) - return DV_trunc - end end # Find truncation From 3c20424094aa65dfc8ec17a5edfeddfca224c2f5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 1 Oct 2025 17:31:32 -0400 Subject: [PATCH 116/126] adapt pullbacks --- src/tensors/factorizations/factorizations.jl | 3 ++- src/tensors/factorizations/pullbacks.jl | 13 ++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index f6f05d0bd..658a0cc1d 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -42,7 +42,8 @@ import MatrixAlgebraKit: default_algorithm, left_orth!, right_orth!, left_null!, right_null!, truncate, findtruncated, findtruncated_svd, diagview, isisometry -using MatrixAlgebraKit: qr_compact_pullback!, lq_compact_pullback!, +using MatrixAlgebraKit: qr_pullback!, qr_null_pullback!, + lq_pullback!, lq_null_pullback!, svd_pullback!, svd_trunc_pullback!, eig_pullback!, eig_trunc_pullback!, eigh_pullback!, eigh_trunc_pullback!, diff --git a/src/tensors/factorizations/pullbacks.jl b/src/tensors/factorizations/pullbacks.jl index daa54f46d..15d395928 100644 --- a/src/tensors/factorizations/pullbacks.jl +++ b/src/tensors/factorizations/pullbacks.jl @@ -1,4 +1,4 @@ -for pullback! in (:qr_compact_pullback!, :lq_compact_pullback!, +for pullback! in (:qr_pullback!, :lq_pullback!, :left_polar_pullback!, :right_polar_pullback!) @eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, F, ΔF; kwargs...) @@ -10,6 +10,17 @@ for pullback! in (:qr_compact_pullback!, :lq_compact_pullback!, return Δt end end +for pullback! in (:qr_null_pullback!, :lq_null_pullback!) + @eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, + F, ΔF; kwargs...) + foreachblock(Δt, t) do c, (Δb, b) + Fc = block(F, c) + ΔFc = block(ΔF, c) + return $pullback!(Δb, b, Fc, ΔFc; kwargs...) + end + return Δt + end +end _notrunc_ind(t) = SectorDict(c => Colon() for c in blocksectors(t)) From 39f20b1a30ae8c49e6eb6995e1b93d9d98c4b18a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 1 Oct 2025 17:32:03 -0400 Subject: [PATCH 117/126] remove factorization chainrules --- .../TensorKitChainRulesCoreExt.jl | 7 - .../factorizations.jl | 209 ------------------ 2 files changed, 216 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl index 727693816..e89f79b55 100644 --- a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -12,16 +12,9 @@ import TensorOperations as TO using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract using VectorInterface: promote_scale, promote_add -using MatrixAlgebraKit -using MatrixAlgebraKit: TruncationStrategy, TruncatedAlgorithm, - svd_pullback!, eig_pullback!, eigh_pullback!, - qr_compact_pullback!, lq_compact_pullback!, - left_polar_pullback!, right_polar_pullback! - include("utility.jl") include("constructors.jl") include("linalg.jl") include("tensoroperations.jl") -include("factorizations.jl") end diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index 2746f5444..a104f408a 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,211 +1,2 @@ # Factorizations rules # -------------------- -function ChainRulesCore.rrule(::typeof(MatrixAlgebraKit.copy_input), f, - t::AbstractTensorMap) - project = ProjectTo(t) - copy_input_pullback(Δt) = (NoTangent(), NoTangent(), project(unthunk(Δt))) - return MatrixAlgebraKit.copy_input(f, t), copy_input_pullback -end - -@non_differentiable MatrixAlgebraKit.initialize_output(f, t::AbstractTensorMap, args...) -@non_differentiable MatrixAlgebraKit.check_input(f, t::AbstractTensorMap, args...) - -for qr_f in (:qr_compact, :qr_full) - qr_f! = Symbol(qr_f, '!') - @eval function ChainRulesCore.rrule(::typeof($qr_f!), t::AbstractTensorMap, QR, alg) - tc = MatrixAlgebraKit.copy_input($qr_f, t) - QR = $(qr_f!)(tc, QR, alg) - function qr_pullback(ΔQR′) - ΔQR = unthunk.(ΔQR′) - Δt = zerovector(t) - MatrixAlgebraKit.qr_compact_pullback!(Δt, t, QR, ΔQR) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - function qr_pullback(::Tuple{ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return QR, qr_pullback - end -end -function ChainRulesCore.rrule(::typeof(qr_null!), t::AbstractTensorMap, N, alg) - Q, R = qr_full(t, alg) - for (c, b) in blocks(t) - m, n = size(b) - copy!(block(N, c), view(block(Q, c), 1:m, (n + 1):m)) - end - - function qr_null_pullback(ΔN′) - ΔN = unthunk(ΔN′) - Δt = zerovector(t) - ΔQ = zerovector!(similar(Q, codomain(Q) ← fuse(codomain(Q)))) - foreachblock(ΔN) do c, (b,) - n = size(b, 2) - ΔQc = block(ΔQ, c) - return copy!(@view(ΔQc[:, (end - n + 1):end]), b) - end - ΔR = ZeroTangent() - MatrixAlgebraKit.qr_compact_pullback!(Δt, t, (Q, R), (ΔQ, ΔR)) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - qr_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - - return N, qr_null_pullback -end - -for lq_f in (:lq_compact, :lq_full) - lq_f! = Symbol(lq_f, '!') - @eval function ChainRulesCore.rrule(::typeof($lq_f!), t::AbstractTensorMap, LQ, alg) - tc = MatrixAlgebraKit.copy_input($lq_f, t) - LQ = $(lq_f!)(tc, LQ, alg) - function lq_pullback(ΔLQ′) - ΔLQ = unthunk.(ΔLQ′) - Δt = zerovector(t) - MatrixAlgebraKit.lq_compact_pullback!(Δt, t, LQ, ΔLQ) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - function lq_pullback(::Tuple{ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return LQ, lq_pullback - end -end -function ChainRulesCore.rrule(::typeof(lq_null!), t::AbstractTensorMap, Nᴴ, alg) - L, Q = lq_full(t, alg) - for (c, b) in blocks(t) - m, n = size(b) - copy!(block(Nᴴ, c), view(block(Q, c), (m + 1):n, 1:n)) - end - - function lq_null_pullback(ΔNᴴ′) - ΔNᴴ = unthunk(ΔNᴴ′) - Δt = zerovector(t) - ΔQ = zerovector!(similar(Q, codomain(Q) ← fuse(codomain(Q)))) - foreachblock(ΔNᴴ) do c, (b,) - m = size(b, 1) - ΔQc = block(ΔQ, c) - return copy!(@view(ΔQc[(end - m + 1):end, :]), b) - end - ΔL = ZeroTangent() - MatrixAlgebraKit.lq_compact_pullback!(Δt, t, (L, Q), (ΔL, ΔQ)) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - lq_null_pullback(::ZeroTangent) = NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - - return Nᴴ, lq_null_pullback -end - -for eig in (:eig, :eigh) - eig_f = Symbol(eig, "_full") - eig_f! = Symbol(eig_f, "!") - eig_f_pb! = Symbol(eig, "_pullback!") - eig_pb = Symbol(eig, "_pullback") - @eval function ChainRulesCore.rrule(::typeof($eig_f!), t::AbstractTensorMap, DV, alg) - tc = MatrixAlgebraKit.copy_input($eig_f, t) - DV = $(eig_f!)(tc, DV, alg) - function $eig_pb(ΔDV) - Δt = zerovector(t) - MatrixAlgebraKit.$eig_f_pb!(Δt, t, DV, unthunk.(ΔDV)) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - function $eig_pb(::Tuple{ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return DV, $eig_pb - end -end - -for svd_f in (:svd_compact, :svd_full) - svd_f! = Symbol(svd_f, "!") - @eval begin - function ChainRulesCore.rrule(::typeof($svd_f!), t::AbstractTensorMap, USVᴴ, alg) - tc = MatrixAlgebraKit.copy_input($svd_f, t) - USVᴴ = $(svd_f!)(tc, USVᴴ, alg) - function svd_pullback(ΔUSVᴴ) - Δt = zerovector(t) - MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ)) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return USVᴴ, svd_pullback - end - end -end - -function ChainRulesCore.rrule(::typeof(svd_trunc!), t::AbstractTensorMap, USVᴴ, - alg::TruncatedAlgorithm) - tc = MatrixAlgebraKit.copy_input(svd_compact, t) - USVᴴ = svd_compact!(tc, USVᴴ, alg.alg) - USVᴴ_trunc, ind = TensorKit.Factorizations.truncate(svd_trunc!, USVᴴ, alg.trunc) - svd_trunc_pullback = _make_svd_trunc_pullback(t, USVᴴ, ind) - return USVᴴ_trunc, svd_trunc_pullback -end -function _make_svd_trunc_pullback(t::AbstractTensorMap, USVᴴ, ind) - function svd_trunc_pullback(ΔUSVᴴ) - Δt = zerovector(t) - MatrixAlgebraKit.svd_pullback!(Δt, t, USVᴴ, unthunk.(ΔUSVᴴ), ind) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - function svd_trunc_pullback(::NTuple{3,ZeroTangent}) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return svd_trunc_pullback -end - -function ChainRulesCore.rrule(::typeof(left_polar!), t::AbstractTensorMap, WP, alg) - tc = MatrixAlgebraKit.copy_input(left_polar, t) - WP = left_polar!(tc, WP, alg) - function left_polar_pullback(ΔWP) - Δt = zerovector(t) - MatrixAlgebraKit.left_polar_pullback!(Δt, t, WP, unthunk.(ΔWP)) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return WP, left_polar_pullback -end - -function ChainRulesCore.rrule(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, alg) - tc = MatrixAlgebraKit.copy_input(left_polar, t) - PWᴴ = right_polar!(tc, PWᴴ, alg) - function right_polar_pullback(ΔPWᴴ) - Δt = zerovector(t) - MatrixAlgebraKit.right_polar_pullback!(Δt, t, PWᴴ, unthunk.(ΔPWᴴ)) - return NoTangent(), Δt, ZeroTangent(), NoTangent() - end - function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return PWᴴ, right_polar_pullback -end - -# function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap) -# U, S, V⁺ = tsvd(t) -# s = diag(S) -# project_t = ProjectTo(t) - -# function svdvals_pullback(Δs′) -# Δs = unthunk(Δs′) -# ΔS = diagm(codomain(S), domain(S), Δs) -# return NoTangent(), project_t(U * ΔS * V⁺) -# end - -# return s, svdvals_pullback -# end - -# function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap; -# sortby=nothing, kwargs...) -# @assert sortby === nothing "only `sortby=nothing` is supported" -# (D, _), eig_pullback = rrule(TensorKit.eig!, t; kwargs...) -# d = diag(D) -# project_t = ProjectTo(t) -# function eigvals_pullback(Δd′) -# Δd = unthunk(Δd′) -# ΔD = diagm(codomain(D), domain(D), Δd) -# return NoTangent(), project_t(eig_pullback((ΔD, ZeroTangent()))[2]) -# end - -# return d, eigvals_pullback -# end From 9f6761fe163fec01e0169f7821bc17e5e9057f99 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 1 Oct 2025 17:39:18 -0400 Subject: [PATCH 118/126] remove some boilerplate and disambiguate --- src/tensors/factorizations/diagonal.jl | 3 +++ .../factorizations/matrixalgebrakit.jl | 26 +------------------ 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index b971d94d2..4b9e3221b 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -105,6 +105,9 @@ for f! in (:lq_full!, :lq_compact!) end end +# disambiguate +svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm) = svd_full!(t, USVᴴ, alg) + # f_vals # ------ diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 02b9347ad..a7452c4e6 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -154,6 +154,7 @@ function initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, return U, S, Vᴴ end +# TODO: remove this once `AbstractMatrix` specialization is removed in MatrixAlgebraKit function initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, alg::TruncatedAlgorithm) return initialize_output(svd_compact!, t, alg.alg) @@ -165,11 +166,6 @@ function initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, return DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) end -function svd_trunc!(t::AbstractTensorMap, USVᴴ, alg::TruncatedAlgorithm) - USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) - return truncate!(svd_trunc!, USVᴴ′, alg.trunc) -end - # Eigenvalue decomposition # ------------------------ function check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) @@ -284,26 +280,6 @@ function initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, return D = DiagonalTensorMap{Tc}(undef, V_D) end -function initialize_output(::typeof(eigh_trunc!), t::AbstractTensorMap, - alg::TruncatedAlgorithm) - return initialize_output(eigh_full!, t, alg.alg) -end - -function initialize_output(::typeof(eig_trunc!), t::AbstractTensorMap, - alg::TruncatedAlgorithm) - return initialize_output(eig_full!, t, alg.alg) -end - -function eigh_trunc!(t::AbstractTensorMap, DV, alg::TruncatedAlgorithm) - DV′ = eigh_full!(t, DV, alg.alg) - return truncate!(eigh_trunc!, DV′, alg.trunc) -end - -function eig_trunc!(t::AbstractTensorMap, DV, alg::TruncatedAlgorithm) - DV′ = eig_full!(t, DV, alg.alg) - return truncate!(eig_trunc!, DV′, alg.trunc) -end - # QR decomposition # ---------------- function check_input(::typeof(qr_full!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) From 01d4f995f6bd449853c935496a6422ff04ddecfe Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 1 Oct 2025 21:14:30 -0400 Subject: [PATCH 119/126] more careful with import and exports --- src/TensorKit.jl | 2 +- src/tensors/factorizations/adjoint.jl | 83 ++++---- src/tensors/factorizations/diagonal.jl | 70 ++++--- src/tensors/factorizations/factorizations.jl | 45 +--- .../factorizations/matrixalgebrakit.jl | 192 ++++++++++-------- src/tensors/factorizations/pullbacks.jl | 28 +-- src/tensors/factorizations/truncation.jl | 66 +++--- src/tensors/factorizations/utility.jl | 2 +- 8 files changed, 242 insertions(+), 246 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 668a32976..235614904 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -109,7 +109,7 @@ using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend const TO = TensorOperations -using MatrixAlgebraKit: MatrixAlgebraKit as MAK +using MatrixAlgebraKit using LRUCache using OhMyThreads diff --git a/src/tensors/factorizations/adjoint.jl b/src/tensors/factorizations/adjoint.jl index 395bc05cf..b4c148788 100644 --- a/src/tensors/factorizations/adjoint.jl +++ b/src/tensors/factorizations/adjoint.jl @@ -2,63 +2,63 @@ # ---------------- # map algorithms to their adjoint counterpart # TODO: this probably belongs in MatrixAlgebraKit -_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.kwargs...) -_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.kwargs...) -_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.kwargs...) -_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.kwargs...) -_adjoint(alg::PolarViaSVD) = PolarViaSVD(_adjoint(alg.svdalg)) +_adjoint(alg::MAK.LAPACK_HouseholderQR) = MAK.LAPACK_HouseholderLQ(; alg.kwargs...) +_adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs...) +_adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...) +_adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...) +_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svdalg)) _adjoint(alg::AbstractAlgorithm) = alg # 1-arg functions -function initialize_output(::typeof(left_null!), t::AdjointTensorMap, - alg::AbstractAlgorithm) - return adjoint(initialize_output(right_null!, adjoint(t), _adjoint(alg))) +function MAK.initialize_output(::typeof(left_null!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return adjoint(MAK.initialize_output(right_null!, adjoint(t), _adjoint(alg))) end -function initialize_output(::typeof(right_null!), t::AdjointTensorMap, - alg::AbstractAlgorithm) - return adjoint(initialize_output(left_null!, adjoint(t), _adjoint(alg))) +function MAK.initialize_output(::typeof(right_null!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return adjoint(MAK.initialize_output(left_null!, adjoint(t), _adjoint(alg))) end -function left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) +function MAK.left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) right_null!(adjoint(t), adjoint(N), _adjoint(alg)) return N end -function right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) +function MAK.right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) left_null!(adjoint(t), adjoint(N), _adjoint(alg)) return N end -function MatrixAlgebraKit.is_left_isometry(t::AdjointTensorMap; kwargs...) +function MAK.is_left_isometry(t::AdjointTensorMap; kwargs...) return is_right_isometry(adjoint(t); kwargs...) end -function MatrixAlgebraKit.is_right_isometry(t::AdjointTensorMap; kwargs...) +function MAK.is_right_isometry(t::AdjointTensorMap; kwargs...) return is_left_isometry(adjoint(t); kwargs...) end # 2-arg functions for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!), (:lq_full!, :lq_compact!, :right_polar!, :right_orth!)) - @eval function copy_input(::typeof($left_f!), t::AdjointTensorMap) - return adjoint(copy_input($right_f!, adjoint(t))) + @eval function MAK.copy_input(::typeof($left_f!), t::AdjointTensorMap) + return adjoint(MAK.copy_input($right_f!, adjoint(t))) end - @eval function copy_input(::typeof($right_f!), t::AdjointTensorMap) - return adjoint(copy_input($left_f!, adjoint(t))) + @eval function MAK.copy_input(::typeof($right_f!), t::AdjointTensorMap) + return adjoint(MAK.copy_input($left_f!, adjoint(t))) end - @eval function initialize_output(::typeof($left_f!), t::AdjointTensorMap, - alg::AbstractAlgorithm) - return reverse(adjoint.(initialize_output($right_f!, adjoint(t), _adjoint(alg)))) + @eval function MAK.initialize_output(::typeof($left_f!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return reverse(adjoint.(MAK.initialize_output($right_f!, adjoint(t), _adjoint(alg)))) end - @eval function initialize_output(::typeof($right_f!), t::AdjointTensorMap, - alg::AbstractAlgorithm) - return reverse(adjoint.(initialize_output($left_f!, adjoint(t), _adjoint(alg)))) + @eval function MAK.initialize_output(::typeof($right_f!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return reverse(adjoint.(MAK.initialize_output($left_f!, adjoint(t), _adjoint(alg)))) end - @eval function $left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) + @eval function MAK.$left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end - @eval function $right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) + @eval function MAK.$right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end @@ -66,32 +66,35 @@ end # 3-arg functions for f! in (:svd_full!, :svd_compact!, :svd_trunc!) - @eval function copy_input(::typeof($f!), t::AdjointTensorMap) - return adjoint(copy_input($f!, adjoint(t))) + @eval function MAK.copy_input(::typeof($f!), t::AdjointTensorMap) + return adjoint(MAK.copy_input($f!, adjoint(t))) end - @eval function initialize_output(::typeof($f!), t::AdjointTensorMap, - alg::AbstractAlgorithm) - return reverse(adjoint.(initialize_output($f!, adjoint(t), _adjoint(alg)))) + @eval function MAK.initialize_output(::typeof($f!), t::AdjointTensorMap, + alg::AbstractAlgorithm) + return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg)))) end - @eval function $f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) + @eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return F end # disambiguate by prohibition - @eval function initialize_output(::typeof($f!), t::AdjointTensorMap, - alg::DiagonalAlgorithm) + @eval function MAK.initialize_output(::typeof($f!), t::AdjointTensorMap, + alg::DiagonalAlgorithm) throw(MethodError($f!, (t, alg))) end end # avoid amgiguity -function initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap, - alg::TruncatedAlgorithm) - return initialize_output(svd_compact!, t, alg.alg) +function MAK.initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap, + alg::TruncatedAlgorithm) + return MAK.initialize_output(svd_compact!, t, alg.alg) end # to fix ambiguity -function svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm) +function MAK.svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg) - return truncate(svd_trunc!, USVᴴ′, alg.trunc) + return MAK.truncate(svd_trunc!, USVᴴ′, alg.trunc) +end +function MAK.svd_compact!(t::AdjointTensorMap, USVᴴ, alg::DiagonalAlgorithm) + return MAK.svd_compact!(t, USVᴴ, alg.alg) end diff --git a/src/tensors/factorizations/diagonal.jl b/src/tensors/factorizations/diagonal.jl index 4b9e3221b..1cbaedafe 100644 --- a/src/tensors/factorizations/diagonal.jl +++ b/src/tensors/factorizations/diagonal.jl @@ -5,19 +5,19 @@ _repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data) for f in (:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, :left_polar, :right_polar) - @eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) + @eval MAK.copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) end for f! in (:eig_full!, :eig_trunc!) - @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, - ::DiagonalAlgorithm) + @eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) return d, similar(d) end end for f! in (:eigh_full!, :eigh_trunc!) - @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, - ::DiagonalAlgorithm) + @eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) if scalartype(d) <: Real return d, similar(d) else @@ -27,36 +27,37 @@ for f! in (:eigh_full!, :eigh_trunc!) end for f! in (:qr_full!, :qr_compact!) - @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, - ::DiagonalAlgorithm) + @eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) return d, similar(d) end # to avoid ambiguities - @eval function initialize_output(::typeof($f!), d::AdjointTensorMap, - ::DiagonalAlgorithm) + @eval function MAK.initialize_output(::typeof($f!), d::AdjointTensorMap, + ::DiagonalAlgorithm) return d, similar(d) end end for f! in (:lq_full!, :lq_compact!) - @eval function initialize_output(::typeof($f!), d::AbstractTensorMap, - ::DiagonalAlgorithm) + @eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap, + ::DiagonalAlgorithm) return similar(d), d end # to avoid ambiguities - @eval function initialize_output(::typeof($f!), d::AdjointTensorMap, - ::DiagonalAlgorithm) + @eval function MAK.initialize_output(::typeof($f!), d::AdjointTensorMap, + ::DiagonalAlgorithm) return similar(d), d end end -function initialize_output(::typeof(left_orth!), d::DiagonalTensorMap) +function MAK.initialize_output(::typeof(left_orth!), d::DiagonalTensorMap) return d, similar(d) end -function initialize_output(::typeof(right_orth!), d::DiagonalTensorMap) +function MAK.initialize_output(::typeof(right_orth!), d::DiagonalTensorMap) return similar(d), d end -function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::DiagonalAlgorithm) +function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, + ::DiagonalAlgorithm) V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) U = similar(t, codomain(t) ← V_cod) @@ -68,16 +69,16 @@ end for f! in (:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!, :eigh_trunc!, :right_orth!, :left_orth!) - @eval function $f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm) - check_input($f!, d, F, alg) + @eval function MAK.$f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm) + MAK.check_input($f!, d, F, alg) $f!(_repack_diagonal(d), _repack_diagonal.(F), alg) return F end end for f! in (:qr_full!, :qr_compact!) - @eval function check_input(::typeof($f!), d::AbstractTensorMap, QR, - ::DiagonalAlgorithm) + @eval function MAK.check_input(::typeof($f!), d::AbstractTensorMap, QR, + ::DiagonalAlgorithm) Q, R = QR @assert d isa DiagonalTensorMap @assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap @@ -91,8 +92,8 @@ for f! in (:qr_full!, :qr_compact!) end for f! in (:lq_full!, :lq_compact!) - @eval function check_input(::typeof($f!), d::AbstractTensorMap, LQ, - ::DiagonalAlgorithm) + @eval function MAK.check_input(::typeof($f!), d::AbstractTensorMap, LQ, + ::DiagonalAlgorithm) L, Q = LQ @assert d isa DiagonalTensorMap @assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap @@ -106,25 +107,27 @@ for f! in (:lq_full!, :lq_compact!) end # disambiguate -svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm) = svd_full!(t, USVᴴ, alg) +function MAK.svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm) + return svd_full!(t, USVᴴ, alg) +end # f_vals # ------ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) - @eval function $f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm) - check_input($f!, d, V, alg) + @eval function MAK.$f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm) + MAK.check_input($f!, d, V, alg) $f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg) return V end - @eval function initialize_output(::typeof($f!), d::DiagonalTensorMap, - alg::DiagonalAlgorithm) - data = initialize_output($f!, _repack_diagonal(d), alg) + @eval function MAK.initialize_output(::typeof($f!), d::DiagonalTensorMap, + alg::DiagonalAlgorithm) + data = MAK.initialize_output($f!, _repack_diagonal(d), alg) return DiagonalTensorMap(data, d.domain) end end -function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm) +function MAK.check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -144,7 +147,8 @@ function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAl return nothing end -function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm) +function MAK.check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, + ::DiagonalAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -164,21 +168,21 @@ function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, ::DiagonalA return nothing end -function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) +function MAK.check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) @assert D isa DiagonalTensorMap @check_scalar D t @check_space D space(t) return nothing end -function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) +function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) @assert D isa DiagonalTensorMap @check_scalar D t real @check_space D space(t) return nothing end -function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) +function MAK.check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm) @assert D isa DiagonalTensorMap @check_scalar D t real @check_space D space(t) diff --git a/src/tensors/factorizations/factorizations.jl b/src/tensors/factorizations/factorizations.jl index 658a0cc1d..310bf10a2 100644 --- a/src/tensors/factorizations/factorizations.jl +++ b/src/tensors/factorizations/factorizations.jl @@ -3,51 +3,24 @@ # using submodule here to import MatrixAlgebraKit functions without polluting namespace module Factorizations -export eig, eig!, eigh, eigh! -export tsvd, tsvd!, svdvals, svdvals! -export leftorth, leftorth!, rightorth, rightorth! -export leftnull, leftnull!, rightnull, rightnull! -export qr_full, qr_compact, qr_null -export qr_full!, qr_compact!, qr_null! -export lq_full, lq_compact, lq_null -export lq_full!, lq_compact!, lq_null! -export copy_oftype, factorisation_scalartype, one! -export TruncationScheme, notrunc, trunctol, truncerror, truncrank, truncspace, truncfilter, - PolarViaSVD +export copy_oftype, factorisation_scalartype, one!, truncspace using ..TensorKit using ..TensorKit: AdjointTensorMap, SectorDict, blocktype, foreachblock, one! -using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals! -import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian +using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!, eigen, eigen!, + isposdef, isposdef!, ishermitian using TensorOperations: Index2Tuple using MatrixAlgebraKit +import MatrixAlgebraKit as MAK using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder -using MatrixAlgebraKit: PolarViaSVD -using MatrixAlgebraKit: LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR, - LAPACK_HouseholderLQ, LAPACK_HouseholderQL, LAPACK_HouseholderRQ -import MatrixAlgebraKit: default_algorithm, - copy_input, check_input, initialize_output, - qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!, - svd_compact!, svd_full!, svd_trunc!, svd_vals!, - eigh_full!, eigh_trunc!, eigh_vals!, - eig_full!, eig_trunc!, eig_vals!, - left_polar!, left_orth_polar!, right_polar!, right_orth_polar!, - left_null_svd!, right_null_svd!, left_orth_svd!, right_orth_svd!, - left_orth!, right_orth!, left_null!, right_null!, - truncate, findtruncated, findtruncated_svd, - diagview, isisometry -using MatrixAlgebraKit: qr_pullback!, qr_null_pullback!, - lq_pullback!, lq_null_pullback!, - svd_pullback!, svd_trunc_pullback!, - eig_pullback!, eig_trunc_pullback!, - eigh_pullback!, eigh_trunc_pullback!, - left_polar_pullback!, right_polar_pullback! +using MatrixAlgebraKit: left_orth_polar!, right_orth_polar!, left_orth_svd!, + right_orth_svd!, left_null_svd!, right_null_svd!, diagview include("utility.jl") include("matrixalgebrakit.jl") @@ -58,7 +31,7 @@ include("pullbacks.jl") TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A) -function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) +function MatrixAlgebraKit.isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) t = permute(t, (p₁, p₂); copy=false) return isisometry(t) end @@ -67,10 +40,10 @@ end # LinearAlgebra overloads #------------------------------# -function eigen(t::AbstractTensorMap; kwargs...) +function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...) return ishermitian(t) ? eigh_full(t; kwargs...) : eig_full(t; kwargs...) end -function eigen!(t::AbstractTensorMap; kwargs...) +function LinearAlgebra.eigen!(t::AbstractTensorMap; kwargs...) return ishermitian(t) ? eigh_full!(t; kwargs...) : eig_full!(t; kwargs...) end diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index a7452c4e6..93c1fe312 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -5,21 +5,18 @@ for f in :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, :left_polar, :right_polar] f! = Symbol(f, :!) - @eval function default_algorithm(::typeof($f!), ::Type{T}; - kwargs...) where {T<:AbstractTensorMap} - return default_algorithm($f!, blocktype(T); kwargs...) + @eval function MAK.default_algorithm(::typeof($f!), ::Type{T}; + kwargs...) where {T<:AbstractTensorMap} + return MAK.default_algorithm($f!, blocktype(T); kwargs...) end - @eval function copy_input(::typeof($f), t::AbstractTensorMap) + @eval function MAK.copy_input(::typeof($f), t::AbstractTensorMap) return copy_oftype(t, factorisation_scalartype($f, t)) end end -function _select_truncation(f, ::AbstractTensorMap, - trunc::MatrixAlgebraKit.TruncationStrategy) - return trunc -end +_select_truncation(f, ::AbstractTensorMap, trunc::TruncationStrategy) = trunc function _select_truncation(::typeof(left_null!), ::AbstractTensorMap, trunc::NamedTuple) - return MatrixAlgebraKit.null_truncation_strategy(; trunc...) + return MAK.null_truncation_strategy(; trunc...) end # Generic Implementations @@ -31,8 +28,8 @@ for f! in (:qr_compact!, :qr_full!, :left_polar!, :left_orth_polar!, :right_polar!, :right_orth_polar!, :left_orth!, :right_orth!) - @eval function $f!(t::AbstractTensorMap, F, alg::AbstractAlgorithm) - check_input($f!, t, F, alg) + @eval function MAK.$f!(t::AbstractTensorMap, F, alg::AbstractAlgorithm) + MAK.check_input($f!, t, F, alg) foreachblock(t, F...) do _, bs factors = Base.tail(bs) @@ -50,8 +47,8 @@ end # Handle these separately because single output instead of tuple for f! in (:qr_null!, :lq_null!) - @eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) - check_input($f!, t, N, alg) + @eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) + MAK.check_input($f!, t, N, alg) foreachblock(t, N) do _, (b, n) n′ = $f!(b, n, alg) @@ -66,8 +63,8 @@ end # Handle these separately because single output instead of tuple for f! in (:svd_vals!, :eig_vals!, :eigh_vals!) - @eval function $f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) - check_input($f!, t, N, alg) + @eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) + MAK.check_input($f!, t, N, alg) foreachblock(t, N) do _, (b, n) n′ = $f!(b, n.diag, alg) @@ -82,7 +79,8 @@ end # Singular value decomposition # ---------------------------- -function check_input(::typeof(svd_full!), t::AbstractTensorMap, USVᴴ, ::AbstractAlgorithm) +function MAK.check_input(::typeof(svd_full!), t::AbstractTensorMap, USVᴴ, + ::AbstractAlgorithm) U, S, Vᴴ = USVᴴ # type checks @@ -105,8 +103,8 @@ function check_input(::typeof(svd_full!), t::AbstractTensorMap, USVᴴ, ::Abstra return nothing end -function check_input(::typeof(svd_compact!), t::AbstractTensorMap, USVᴴ, - ::AbstractAlgorithm) +function MAK.check_input(::typeof(svd_compact!), t::AbstractTensorMap, USVᴴ, + ::AbstractAlgorithm) U, S, Vᴴ = USVᴴ # type checks @@ -128,7 +126,7 @@ function check_input(::typeof(svd_compact!), t::AbstractTensorMap, USVᴴ, return nothing end -function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) +function MAK.check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t real @assert D isa DiagonalTensorMap V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) @@ -136,7 +134,8 @@ function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::AbstractAlg return nothing end -function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_cod = fuse(codomain(t)) V_dom = fuse(domain(t)) U = similar(t, codomain(t) ← V_cod) @@ -145,8 +144,8 @@ function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::Abstract return U, S, Vᴴ end -function initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, - ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) U = similar(t, codomain(t) ← V_cod) S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) @@ -155,20 +154,21 @@ function initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, end # TODO: remove this once `AbstractMatrix` specialization is removed in MatrixAlgebraKit -function initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, - alg::TruncatedAlgorithm) - return initialize_output(svd_compact!, t, alg.alg) +function MAK.initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap, + alg::TruncatedAlgorithm) + return MAK.initialize_output(svd_compact!, t, alg.alg) end -function initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, - alg::AbstractAlgorithm) +function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, + alg::AbstractAlgorithm) V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) return DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) end # Eigenvalue decomposition # ------------------------ -function check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) +function MAK.check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, + ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -190,7 +190,7 @@ function check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::AbstractA return nothing end -function check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) +function MAK.check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -212,7 +212,7 @@ function check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::AbstractAl return nothing end -function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::AbstractAlgorithm) +function MAK.check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::AbstractAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -234,7 +234,7 @@ function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::AbstractAl return nothing end -function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) +function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t real @assert D isa DiagonalTensorMap V_D = fuse(domain(t)) @@ -242,7 +242,7 @@ function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::AbstractAl return nothing end -function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) +function MAK.check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t complex @assert D isa DiagonalTensorMap V_D = fuse(domain(t)) @@ -250,7 +250,8 @@ function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::AbstractAlg return nothing end -function initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_D = fuse(domain(t)) T = real(scalartype(t)) D = DiagonalTensorMap{T}(undef, V_D) @@ -258,7 +259,8 @@ function initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::Abstrac return D, V end -function initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(eig_full!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_D = fuse(domain(t)) Tc = complex(scalartype(t)) D = DiagonalTensorMap{Tc}(undef, V_D) @@ -266,15 +268,15 @@ function initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::Abstract return D, V end -function initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, - alg::AbstractAlgorithm) +function MAK.initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, + alg::AbstractAlgorithm) V_D = fuse(domain(t)) T = real(scalartype(t)) return D = DiagonalTensorMap{Tc}(undef, V_D) end -function initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, - alg::AbstractAlgorithm) +function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, + alg::AbstractAlgorithm) V_D = fuse(domain(t)) Tc = complex(scalartype(t)) return D = DiagonalTensorMap{Tc}(undef, V_D) @@ -282,7 +284,7 @@ end # QR decomposition # ---------------- -function check_input(::typeof(qr_full!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) +function MAK.check_input(::typeof(qr_full!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) Q, R = QR # type checks @@ -301,7 +303,8 @@ function check_input(::typeof(qr_full!), t::AbstractTensorMap, QR, ::AbstractAlg return nothing end -function check_input(::typeof(qr_compact!), t::AbstractTensorMap, QR, ::AbstractAlgorithm) +function MAK.check_input(::typeof(qr_compact!), t::AbstractTensorMap, QR, + ::AbstractAlgorithm) Q, R = QR # type checks @@ -320,7 +323,7 @@ function check_input(::typeof(qr_compact!), t::AbstractTensorMap, QR, ::Abstract return nothing end -function check_input(::typeof(qr_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) +function MAK.check_input(::typeof(qr_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -332,21 +335,24 @@ function check_input(::typeof(qr_null!), t::AbstractTensorMap, N, ::AbstractAlgo return nothing end -function initialize_output(::typeof(qr_full!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(qr_full!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_Q = fuse(codomain(t)) Q = similar(t, codomain(t) ← V_Q) R = similar(t, V_Q ← domain(t)) return Q, R end -function initialize_output(::typeof(qr_compact!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(qr_compact!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) Q = similar(t, codomain(t) ← V_Q) R = similar(t, V_Q ← domain(t)) return Q, R end -function initialize_output(::typeof(qr_null!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(qr_null!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) V_N = ⊖(fuse(codomain(t)), V_Q) N = similar(t, codomain(t) ← V_N) @@ -355,7 +361,7 @@ end # LQ decomposition # ---------------- -function check_input(::typeof(lq_full!), t::AbstractTensorMap, LQ, ::AbstractAlgorithm) +function MAK.check_input(::typeof(lq_full!), t::AbstractTensorMap, LQ, ::AbstractAlgorithm) L, Q = LQ # type checks @@ -374,7 +380,8 @@ function check_input(::typeof(lq_full!), t::AbstractTensorMap, LQ, ::AbstractAlg return nothing end -function check_input(::typeof(lq_compact!), t::AbstractTensorMap, LQ, ::AbstractAlgorithm) +function MAK.check_input(::typeof(lq_compact!), t::AbstractTensorMap, LQ, + ::AbstractAlgorithm) L, Q = LQ # type checks @@ -393,7 +400,7 @@ function check_input(::typeof(lq_compact!), t::AbstractTensorMap, LQ, ::Abstract return nothing end -function check_input(::typeof(lq_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) +function MAK.check_input(::typeof(lq_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -405,21 +412,24 @@ function check_input(::typeof(lq_null!), t::AbstractTensorMap, N, ::AbstractAlgo return nothing end -function initialize_output(::typeof(lq_full!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(lq_full!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_Q = fuse(domain(t)) L = similar(t, codomain(t) ← V_Q) Q = similar(t, V_Q ← domain(t)) return L, Q end -function initialize_output(::typeof(lq_compact!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(lq_compact!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) L = similar(t, codomain(t) ← V_Q) Q = similar(t, V_Q ← domain(t)) return L, Q end -function initialize_output(::typeof(lq_null!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(lq_null!), t::AbstractTensorMap, + ::AbstractAlgorithm) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) V_N = ⊖(fuse(domain(t)), V_Q) N = similar(t, V_N ← domain(t)) @@ -428,7 +438,8 @@ end # Polar decomposition # ------------------- -function check_input(::typeof(left_polar!), t::AbstractTensorMap, WP, ::AbstractAlgorithm) +function MAK.check_input(::typeof(left_polar!), t::AbstractTensorMap, WP, + ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) @@ -447,8 +458,8 @@ function check_input(::typeof(left_polar!), t::AbstractTensorMap, WP, ::Abstract return nothing end -function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, WP, - ::AbstractAlgorithm) +function MAK.check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, WP, + ::AbstractAlgorithm) codomain(t) ≿ domain(t) || throw(ArgumentError("Polar decomposition requires `codomain(t) ≿ domain(t)`")) @@ -468,13 +479,15 @@ function check_input(::typeof(left_orth_polar!), t::AbstractTensorMap, WP, return nothing end -function initialize_output(::typeof(left_polar!), t::AbstractTensorMap, ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(left_polar!), t::AbstractTensorMap, + ::AbstractAlgorithm) W = similar(t, space(t)) P = similar(t, domain(t) ← domain(t)) return W, P end -function check_input(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, ::AbstractAlgorithm) +function MAK.check_input(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, + ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) @@ -493,8 +506,8 @@ function check_input(::typeof(right_polar!), t::AbstractTensorMap, PWᴴ, ::Abst return nothing end -function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, PWᴴ, - ::AbstractAlgorithm) +function MAK.check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, PWᴴ, + ::AbstractAlgorithm) codomain(t) ≾ domain(t) || throw(ArgumentError("Polar decomposition requires `domain(t) ≿ codomain(t)`")) @@ -514,8 +527,8 @@ function check_input(::typeof(right_orth_polar!), t::AbstractTensorMap, PWᴴ, return nothing end -function initialize_output(::typeof(right_polar!), t::AbstractTensorMap, - ::AbstractAlgorithm) +function MAK.initialize_output(::typeof(right_polar!), t::AbstractTensorMap, + ::AbstractAlgorithm) P = similar(t, codomain(t) ← codomain(t)) Wᴴ = similar(t, space(t)) return P, Wᴴ @@ -523,7 +536,8 @@ end # Orthogonalization # ----------------- -function check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, ::AbstractAlgorithm) +function MAK.check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, + ::AbstractAlgorithm) V, C = VC # scalartype checks @@ -538,7 +552,8 @@ function check_input(::typeof(left_orth!), t::AbstractTensorMap, VC, ::AbstractA return nothing end -function check_input(::typeof(right_orth!), t::AbstractTensorMap, CVᴴ, ::AbstractAlgorithm) +function MAK.check_input(::typeof(right_orth!), t::AbstractTensorMap, CVᴴ, + ::AbstractAlgorithm) C, Vᴴ = CVᴴ # scalartype checks @@ -553,14 +568,14 @@ function check_input(::typeof(right_orth!), t::AbstractTensorMap, CVᴴ, ::Abstr return nothing end -function initialize_output(::typeof(left_orth!), t::AbstractTensorMap) +function MAK.initialize_output(::typeof(left_orth!), t::AbstractTensorMap) V_C = infimum(fuse(codomain(t)), fuse(domain(t))) V = similar(t, codomain(t) ← V_C) C = similar(t, V_C ← domain(t)) return V, C end -function initialize_output(::typeof(right_orth!), t::AbstractTensorMap) +function MAK.initialize_output(::typeof(right_orth!), t::AbstractTensorMap) V_C = infimum(fuse(codomain(t)), fuse(domain(t))) C = similar(t, codomain(t) ← V_C) Vᴴ = similar(t, V_C ← domain(t)) @@ -572,10 +587,10 @@ end # providing output arguments for left_ and right_orth. # This is mainly because polar decompositions have different shapes, and SVD for Diagonal # also does -function left_orth!(t::AbstractTensorMap; - trunc::TruncationStrategy=notrunc(), - kind=trunc == notrunc() ? :qr : :svd, - alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;)) +function MAK.left_orth!(t::AbstractTensorMap; + trunc::TruncationStrategy=notrunc(), + kind=trunc == notrunc() ? :qr : :svd, + alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;)) trunc == notrunc() || kind === :svd || throw(ArgumentError("truncation not supported for left_orth with kind = $kind")) @@ -591,10 +606,10 @@ function left_orth!(t::AbstractTensorMap; throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`")) end end -function right_orth!(t::AbstractTensorMap; - trunc::TruncationStrategy=notrunc(), - kind=trunc == notrunc() ? :lq : :svd, - alg_lq=(; positive=true), alg_polar=(;), alg_svd=(;)) +function MAK.right_orth!(t::AbstractTensorMap; + trunc::TruncationStrategy=notrunc(), + kind=trunc == notrunc() ? :lq : :svd, + alg_lq=(; positive=true), alg_polar=(;), alg_svd=(;)) trunc == notrunc() || kind === :svd || throw(ArgumentError("truncation not supported for right_orth with kind = $kind")) @@ -611,31 +626,31 @@ function right_orth!(t::AbstractTensorMap; end end -function left_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...) - alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg; kwargs...) - VC = initialize_output(left_orth!, t) +function MAK.left_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...) + alg′ = MAK.select_algorithm(left_polar!, t, alg; kwargs...) + VC = MAK.initialize_output(left_orth!, t) return left_orth_polar!(t, VC, alg′) end -function left_orth_polar!(t::AbstractTensorMap, VC, alg) - alg′ = MatrixAlgebraKit.select_algorithm(left_polar!, t, alg) +function MAK.left_orth_polar!(t::AbstractTensorMap, VC, alg) + alg′ = MAK.select_algorithm(left_polar!, t, alg) return left_orth_polar!(t, VC, alg′) end -function right_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...) - alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg; kwargs...) - CVᴴ = initialize_output(right_orth!, t) +function MAK.right_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...) + alg′ = MAK.select_algorithm(right_polar!, t, alg; kwargs...) + CVᴴ = MAK.initialize_output(right_orth!, t) return right_orth_polar!(t, CVᴴ, alg′) end -function right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg) - alg′ = MatrixAlgebraKit.select_algorithm(right_polar!, t, alg) +function MAK.right_orth_polar!(t::AbstractTensorMap, CVᴴ, alg) + alg′ = MAK.select_algorithm(right_polar!, t, alg) return right_orth_polar!(t, CVᴴ, alg′) end -function left_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) +function MAK.left_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) : svd_trunc!(t; trunc, kwargs...) return U, lmul!(S, Vᴴ) end -function right_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) +function MAK.right_orth_svd!(t::AbstractTensorMap; trunc=notrunc(), kwargs...) U, S, Vᴴ = trunc == notrunc() ? svd_compact!(t; kwargs...) : svd_trunc!(t; trunc, kwargs...) return rmul!(U, S), Vᴴ @@ -643,7 +658,7 @@ end # Nullspace # --------- -function check_input(::typeof(left_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) +function MAK.check_input(::typeof(left_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) # scalartype checks @check_scalar N t @@ -655,7 +670,8 @@ function check_input(::typeof(left_null!), t::AbstractTensorMap, N, ::AbstractAl return nothing end -function check_input(::typeof(right_null!), t::AbstractTensorMap, N, ::AbstractAlgorithm) +function MAK.check_input(::typeof(right_null!), t::AbstractTensorMap, N, + ::AbstractAlgorithm) @check_scalar N t # space checks @@ -666,14 +682,14 @@ function check_input(::typeof(right_null!), t::AbstractTensorMap, N, ::AbstractA return nothing end -function initialize_output(::typeof(left_null!), t::AbstractTensorMap) +function MAK.initialize_output(::typeof(left_null!), t::AbstractTensorMap) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) V_N = ⊖(fuse(codomain(t)), V_Q) N = similar(t, codomain(t) ← V_N) return N end -function initialize_output(::typeof(right_null!), t::AbstractTensorMap) +function MAK.initialize_output(::typeof(right_null!), t::AbstractTensorMap) V_Q = infimum(fuse(codomain(t)), fuse(domain(t))) V_N = ⊖(fuse(domain(t)), V_Q) N = similar(t, V_N ← domain(t)) @@ -681,7 +697,7 @@ function initialize_output(::typeof(right_null!), t::AbstractTensorMap) end for (f!, f_svd!) in zip((:left_null!, :right_null!), (:left_null_svd!, :right_null_svd!)) - @eval function $f_svd!(t::AbstractTensorMap, N, alg, ::Nothing=nothing) + @eval function MAK.$f_svd!(t::AbstractTensorMap, N, alg, ::Nothing=nothing) return $f!(t, N; alg_svd=alg) end end diff --git a/src/tensors/factorizations/pullbacks.jl b/src/tensors/factorizations/pullbacks.jl index 15d395928..392304bd5 100644 --- a/src/tensors/factorizations/pullbacks.jl +++ b/src/tensors/factorizations/pullbacks.jl @@ -1,22 +1,22 @@ for pullback! in (:qr_pullback!, :lq_pullback!, :left_polar_pullback!, :right_polar_pullback!) - @eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, - F, ΔF; kwargs...) + @eval function MAK.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, + F, ΔF; kwargs...) foreachblock(Δt, t) do c, (Δb, b) Fc = block.(F, Ref(c)) ΔFc = block.(ΔF, Ref(c)) - return $pullback!(Δb, b, Fc, ΔFc; kwargs...) + return MAK.$pullback!(Δb, b, Fc, ΔFc; kwargs...) end return Δt end end for pullback! in (:qr_null_pullback!, :lq_null_pullback!) - @eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, - F, ΔF; kwargs...) + @eval function MAK.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, + F, ΔF; kwargs...) foreachblock(Δt, t) do c, (Δb, b) Fc = block(F, c) ΔFc = block(ΔF, c) - return $pullback!(Δb, b, Fc, ΔFc; kwargs...) + return MAK.$pullback!(Δb, b, Fc, ΔFc; kwargs...) end return Δt end @@ -25,28 +25,28 @@ end _notrunc_ind(t) = SectorDict(c => Colon() for c in blocksectors(t)) for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!) - @eval function MatrixAlgebraKit.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, - F, ΔF, inds=_notrunc_ind(t); - kwargs...) + @eval function MAK.$pullback!(Δt::AbstractTensorMap, t::AbstractTensorMap, + F, ΔF, inds=_notrunc_ind(t); + kwargs...) for (c, ind) in inds Δb = block(Δt, c) b = block(t, c) Fc = block.(F, Ref(c)) ΔFc = block.(ΔF, Ref(c)) - $pullback!(Δb, b, Fc, ΔFc, ind; kwargs...) + MAK.$pullback!(Δb, b, Fc, ΔFc, ind; kwargs...) end return Δt end end for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_pullback!) - @eval function MatrixAlgebraKit.$pullback_trunc!(Δt::AbstractTensorMap, - t::AbstractTensorMap, - F, ΔF; kwargs...) + @eval function MAK.$pullback_trunc!(Δt::AbstractTensorMap, + t::AbstractTensorMap, + F, ΔF; kwargs...) foreachblock(Δt, t) do c, (Δb, b) Fc = block.(F, Ref(c)) ΔFc = block.(ΔF, Ref(c)) - return $pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...) + return MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...) end return Δt end diff --git a/src/tensors/factorizations/truncation.jl b/src/tensors/factorizations/truncation.jl index 76b02fda0..dcf64177b 100644 --- a/src/tensors/factorizations/truncation.jl +++ b/src/tensors/factorizations/truncation.jl @@ -58,9 +58,9 @@ function truncate_diagonal!(Ddst::DiagonalTensorMap, Dsrc::DiagonalTensorMap, in return Ddst end -function truncate(::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3,AbstractTensorMap}, - strategy::TruncationStrategy) - ind = findtruncated_svd(diagview(S), strategy) +function MAK.truncate(::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3,AbstractTensorMap}, + strategy::TruncationStrategy) + ind = MAK.findtruncated_svd(diagview(S), strategy) V_truncated = truncate_space(space(S, 1), ind) Ũ = similar(U, codomain(U) ← V_truncated) @@ -73,13 +73,13 @@ function truncate(::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3,AbstractTensorMap return (Ũ, S̃, Ṽᴴ), ind end -function truncate(::typeof(left_null!), - (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, - strategy::MatrixAlgebraKit.TruncationStrategy) +function MAK.truncate(::typeof(left_null!), + (U, S)::Tuple{AbstractTensorMap,AbstractTensorMap}, + strategy::MatrixAlgebraKit.TruncationStrategy) extended_S = SectorDict(c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) for (c, b) in blocks(S)) - ind = findtruncated(extended_S, strategy) + ind = MAK.findtruncated(extended_S, strategy) V_truncated = truncate_space(space(S, 1), ind) Ũ = similar(U, codomain(U) ← V_truncated) truncate_domain!(Ũ, U, ind) @@ -87,10 +87,10 @@ function truncate(::typeof(left_null!), end for f! in (:eig_trunc!, :eigh_trunc!) - @eval function truncate(::typeof($f!), - (D, V)::Tuple{DiagonalTensorMap,AbstractTensorMap}, - strategy::TruncationStrategy) - ind = findtruncated(diagview(D), strategy) + @eval function MAK.truncate(::typeof($f!), + (D, V)::Tuple{DiagonalTensorMap,AbstractTensorMap}, + strategy::TruncationStrategy) + ind = MAK.findtruncated(diagview(D), strategy) V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) @@ -138,21 +138,21 @@ end # findtruncated # ------------- # Generic fallback -function findtruncated_svd(values::SectorDict, strategy::TruncationStrategy) - return findtruncated(values, strategy) +function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationStrategy) + return MAK.findtruncated(values, strategy) end -function findtruncated(values::SectorDict, ::NoTruncation) +function MAK.findtruncated(values::SectorDict, ::NoTruncation) return SectorDict(c => Colon() for (c, b) in values) end -function findtruncated(values::SectorDict, strategy::TruncationByOrder) +function MAK.findtruncated(values::SectorDict, strategy::TruncationByOrder) perms = SectorDict(c => (sortperm(d; strategy.by, strategy.rev)) for (c, d) in values) values_sorted = SectorDict(c => d[perms[c]] for (c, d) in values) - inds = findtruncated_svd(values_sorted, truncrank(strategy.howmany)) + inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) return SectorDict(c => perms[c][I] for (c, I) in inds) end -function findtruncated_svd(values::SectorDict, strategy::TruncationByOrder) +function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByOrder) I = keytype(values) truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in values) totaldim = sum(dim(c) * d for (c, d) in truncdim; init=0) @@ -168,28 +168,28 @@ function findtruncated_svd(values::SectorDict, strategy::TruncationByOrder) return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim) end -function findtruncated(values::SectorDict, strategy::TruncationByFilter) +function MAK.findtruncated(values::SectorDict, strategy::TruncationByFilter) return SectorDict(c => findall(strategy.filter, d) for (c, d) in values) end -function findtruncated(values::SectorDict, strategy::TruncationByValue) +function MAK.findtruncated(values::SectorDict, strategy::TruncationByValue) atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) - return SectorDict(c => findtruncated(d, strategy′) for (c, d) in values) + return SectorDict(c => MAK.findtruncated(d, strategy′) for (c, d) in values) end -function findtruncated_svd(values::SectorDict, strategy::TruncationByValue) +function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByValue) atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) - return SectorDict(c => findtruncated_svd(d, strategy′) for (c, d) in values) + return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in values) end -function findtruncated(values::SectorDict, strategy::TruncationByError) +function MAK.findtruncated(values::SectorDict, strategy::TruncationByError) perms = SectorDict(c => sortperm(d; by=abs, rev=true) for (c, d) in values) values_sorted = SectorDict(c => d[perms[c]] for (c, d) in Sd) - inds = findtruncated_svd(values_sorted, truncrank(strategy.howmany)) + inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) return SectorDict(c => perms[c][I] for (c, I) in inds) end -function findtruncated_svd(values::SectorDict, strategy::TruncationByError) +function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByError) I = keytype(values) truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in values) by(c, v) = abs(v)^strategy.p * dim(c) @@ -207,23 +207,23 @@ function findtruncated_svd(values::SectorDict, strategy::TruncationByError) return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) end -function findtruncated(values::SectorDict, strategy::TruncationSpace) +function MAK.findtruncated(values::SectorDict, strategy::TruncationSpace) blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) - return SectorDict(c => findtruncated(d, blockstrategy(c)) for (c, d) in values) + return SectorDict(c => MAK.findtruncated(d, blockstrategy(c)) for (c, d) in values) end -function findtruncated_svd(values::SectorDict, strategy::TruncationSpace) +function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationSpace) blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) - return SectorDict(c => findtruncated_svd(d, blockstrategy(c)) for (c, d) in values) + return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in values) end -function findtruncated(values::SectorDict, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated, values), strategy) +function MAK.findtruncated(values::SectorDict, strategy::TruncationIntersection) + inds = map(Base.Fix1(MAK.findtruncated, values), strategy) return SectorDict(c => mapreduce(Base.Fix2(getindex, c), _ind_intersect, inds; init=trues(length(values[c]))) for c in intersect(map(keys, inds)...)) end -function findtruncated_svd(Sd::SectorDict, strategy::TruncationIntersection) - inds = map(Base.Fix1(findtruncated_svd, Sd), strategy) +function MAK.findtruncated_svd(Sd::SectorDict, strategy::TruncationIntersection) + inds = map(Base.Fix1(MAK.findtruncated_svd, Sd), strategy) return SectorDict(c => mapreduce(Base.Fix2(getindex, c), _ind_intersect, inds; init=trues(length(values[c]))) for c in intersect(map(keys, inds)...)) diff --git a/src/tensors/factorizations/utility.jl b/src/tensors/factorizations/utility.jl index a6721ee31..46e88c8a9 100644 --- a/src/tensors/factorizations/utility.jl +++ b/src/tensors/factorizations/utility.jl @@ -23,4 +23,4 @@ function _reverse!(t::AbstractTensorMap; dims=:) return t end -diagview(t::AbstractTensorMap) = SectorDict(c => diagview(b) for (c, b) in blocks(t)) +MAK.diagview(t::AbstractTensorMap) = SectorDict(c => diagview(b) for (c, b) in blocks(t)) From 23b844311237d6a9483644e48554d7cdcaa7df2f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 1 Oct 2025 21:16:46 -0400 Subject: [PATCH 120/126] copyto -> copy --- src/tensors/factorizations/matrixalgebrakit.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/tensors/factorizations/matrixalgebrakit.jl index 93c1fe312..59f24fd7f 100644 --- a/src/tensors/factorizations/matrixalgebrakit.jl +++ b/src/tensors/factorizations/matrixalgebrakit.jl @@ -36,7 +36,7 @@ for f! in (:qr_compact!, :qr_full!, factors′ = $f!(first(bs), factors, alg) # deal with the case where the output is not in-place for (f′, f) in zip(factors′, factors) - f′ === f || copyto!(f, f′) + f′ === f || copy!(f, f′) end return nothing end @@ -53,7 +53,7 @@ for f! in (:qr_null!, :lq_null!) foreachblock(t, N) do _, (b, n) n′ = $f!(b, n, alg) # deal with the case where the output is not the same as the input - n === n′ || copyto!(n, n′) + n === n′ || copy!(n, n′) return nothing end @@ -67,9 +67,9 @@ for f! in (:svd_vals!, :eig_vals!, :eigh_vals!) MAK.check_input($f!, t, N, alg) foreachblock(t, N) do _, (b, n) - n′ = $f!(b, n.diag, alg) + n′ = $f!(b, diagview(n), alg) # deal with the case where the output is not the same as the input - n.diag === n′ || copyto!(n, diagview(n′)) + diagview(n) === n′ || copy!(diagview(n), n′) return nothing end From 6c703e7e107eb63bec6df68faa38283ce00fb1f7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 2 Oct 2025 11:28:03 -0400 Subject: [PATCH 121/126] some slight test tweaks --- test/ad.jl | 8 ++++---- test/runtests.jl | 32 ++++++++++++++------------------ 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index dd820b21f..c6f03742c 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -151,16 +151,16 @@ ChainRulesTestUtils.test_method_tables() spacelist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), (Vect[Z2Irrep](0 => 1, 1 => 1), Vect[Z2Irrep](0 => 1, 1 => 2)', - Vect[Z2Irrep](0 => 3, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', Vect[Z2Irrep](0 => 2, 1 => 3), Vect[Z2Irrep](0 => 2, 1 => 2)), (Vect[FermionParity](0 => 1, 1 => 1), Vect[FermionParity](0 => 1, 1 => 2)', - Vect[FermionParity](0 => 2, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', Vect[FermionParity](0 => 2, 1 => 3), Vect[FermionParity](0 => 2, 1 => 2)), (Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), - Vect[U1Irrep](0 => 3, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)'), @@ -171,7 +171,7 @@ spacelist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)'), (Vect[FibonacciAnyon](:I => 1, :τ => 1), Vect[FibonacciAnyon](:I => 1, :τ => 2)', - Vect[FibonacciAnyon](:I => 3, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', Vect[FibonacciAnyon](:I => 2, :τ => 3), Vect[FibonacciAnyon](:I => 2, :τ => 2))) diff --git a/test/runtests.jl b/test/runtests.jl index ade9f7cb5..f9829c237 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,11 +59,7 @@ sectorlist = (Z2Irrep, Z3Irrep, Z4Irrep, Z3Irrep ⊠ Z4Irrep, Z2Irrep ⊠ FibonacciAnyon ⊠ FibonacciAnyon) # spaces -Vtr = (ℂ^3, - (ℂ^4)', - ℂ^5, - ℂ^6, - (ℂ^7)') +Vtr = (ℂ^2, (ℂ^3)', ℂ^4, ℂ^3, (ℂ^2)') Vℤ₂ = (ℂ[Z2Irrep](0 => 1, 1 => 1), ℂ[Z2Irrep](0 => 1, 1 => 2)', ℂ[Z2Irrep](0 => 3, 1 => 2)', @@ -71,12 +67,12 @@ Vℤ₂ = (ℂ[Z2Irrep](0 => 1, 1 => 1), ℂ[Z2Irrep](0 => 2, 1 => 5)) Vfℤ₂ = (ℂ[FermionParity](0 => 1, 1 => 1), ℂ[FermionParity](0 => 1, 1 => 2)', - ℂ[FermionParity](0 => 3, 1 => 2)', + ℂ[FermionParity](0 => 2, 1 => 1)', ℂ[FermionParity](0 => 2, 1 => 3), ℂ[FermionParity](0 => 2, 1 => 5)) -Vℤ₃ = (ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 2), - ℂ[Z3Irrep](0 => 3, 1 => 1, 2 => 1), - ℂ[Z3Irrep](0 => 2, 1 => 2, 2 => 1)', +Vℤ₃ = (ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 1), + ℂ[Z3Irrep](0 => 2, 1 => 1, 2 => 1), + ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 1)', ℂ[Z3Irrep](0 => 1, 1 => 2, 2 => 3), ℂ[Z3Irrep](0 => 1, 1 => 3, 2 => 3)') VU₁ = (ℂ[U1Irrep](0 => 1, 1 => 2, -1 => 2), @@ -118,16 +114,16 @@ Vfib = (Vect[FibonacciAnyon](:I => 1, :τ => 1), if !is_buildkite Ti = time() - include("fusiontrees.jl") - include("spaces.jl") - include("tensors.jl") - include("factorizations.jl") - include("diagonal.jl") - include("planar.jl") - if isempty(VERSION.prerelease) - include("ad.jl") + @time include("fusiontrees.jl") + @time include("spaces.jl") + @time include("tensors.jl") + @time include("factorizations.jl") + @time include("diagonal.jl") + @time include("planar.jl") + if !(Sys.isapple() && get(ENV, "CI", "false") == "true") && isempty(VERSION.prerelease) + @time include("ad.jl") end - include("bugfixes.jl") + @time include("bugfixes.jl") Tf = time() printstyled("Finished all tests in ", string(round((Tf - Ti) / 60; sigdigits=3)), From 1ed898a1d62dca242a1e50983e4145f0be5c8cc8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 2 Oct 2025 11:29:02 -0400 Subject: [PATCH 122/126] move factorizations folder --- src/TensorKit.jl | 2 +- src/{tensors => }/factorizations/adjoint.jl | 0 src/{tensors => }/factorizations/diagonal.jl | 0 src/{tensors => }/factorizations/factorizations.jl | 0 src/{tensors => }/factorizations/matrixalgebrakit.jl | 0 src/{tensors => }/factorizations/pullbacks.jl | 0 src/{tensors => }/factorizations/truncation.jl | 0 src/{tensors => }/factorizations/utility.jl | 0 8 files changed, 1 insertion(+), 1 deletion(-) rename src/{tensors => }/factorizations/adjoint.jl (100%) rename src/{tensors => }/factorizations/diagonal.jl (100%) rename src/{tensors => }/factorizations/factorizations.jl (100%) rename src/{tensors => }/factorizations/matrixalgebrakit.jl (100%) rename src/{tensors => }/factorizations/pullbacks.jl (100%) rename src/{tensors => }/factorizations/truncation.jl (100%) rename src/{tensors => }/factorizations/utility.jl (100%) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 235614904..ce6c7697b 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -223,7 +223,7 @@ include("tensors/indexmanipulations.jl") include("tensors/diagonal.jl") include("tensors/braidingtensor.jl") -include("tensors/factorizations/factorizations.jl") +include("factorizations/factorizations.jl") using .Factorizations # # Planar macros and related functionality diff --git a/src/tensors/factorizations/adjoint.jl b/src/factorizations/adjoint.jl similarity index 100% rename from src/tensors/factorizations/adjoint.jl rename to src/factorizations/adjoint.jl diff --git a/src/tensors/factorizations/diagonal.jl b/src/factorizations/diagonal.jl similarity index 100% rename from src/tensors/factorizations/diagonal.jl rename to src/factorizations/diagonal.jl diff --git a/src/tensors/factorizations/factorizations.jl b/src/factorizations/factorizations.jl similarity index 100% rename from src/tensors/factorizations/factorizations.jl rename to src/factorizations/factorizations.jl diff --git a/src/tensors/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl similarity index 100% rename from src/tensors/factorizations/matrixalgebrakit.jl rename to src/factorizations/matrixalgebrakit.jl diff --git a/src/tensors/factorizations/pullbacks.jl b/src/factorizations/pullbacks.jl similarity index 100% rename from src/tensors/factorizations/pullbacks.jl rename to src/factorizations/pullbacks.jl diff --git a/src/tensors/factorizations/truncation.jl b/src/factorizations/truncation.jl similarity index 100% rename from src/tensors/factorizations/truncation.jl rename to src/factorizations/truncation.jl diff --git a/src/tensors/factorizations/utility.jl b/src/factorizations/utility.jl similarity index 100% rename from src/tensors/factorizations/utility.jl rename to src/factorizations/utility.jl From 01019af71f9144e08ded52342e2b1cd21d3db26f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 2 Oct 2025 13:22:16 -0400 Subject: [PATCH 123/126] fix some AD tests --- test/ad.jl | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index c6f03742c..ddd419db1 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,6 @@ using ChainRulesCore using ChainRulesTestUtils -using FiniteDifferences: FiniteDifferences +using FiniteDifferences: FiniteDifferences, central_fdm, forward_fdm using Random using LinearAlgebra using Zygote @@ -302,12 +302,25 @@ for V in spacelist t1 = randn(T, V[1] ← V[1]) t2 = randn(T, V[2] ← V[2]) d = DiagonalTensorMap{T}(undef, V[1]) - (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data) d2 = DiagonalTensorMap{T}(undef, V[1]) - (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data) + d3 = DiagonalTensorMap{T}(undef, V[1]) + if (T <: Real && f === sqrt) + # ensuring no square root of negative numbers + randexp!(d.data) + d.data .+= 5 + randexp!(d2.data) + d2.data .+= 5 + randexp!(d3.data) + d3.data .+= 5 + else + randn!(d.data) + randn!(d2.data) + randn!(d3.data) + end + test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred) test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred) - test_rrule(f, d; check_inferred, output_tangent=d2) + test_rrule(f, d ⊢ d2; check_inferred, output_tangent=d3) end end @@ -516,7 +529,7 @@ for V in spacelist test_ad_rrule(last ∘ eig_full, t; output_tangent=Δv, atol, rtol) test_ad_rrule(eig_full, t; output_tangent=(Δd2, Δv), atol, rtol) - add!(t, t') + t += t' d, v = eigh_full(t) Δv = rand_tangent(v) Δd = rand_tangent(d) From fe980c07859091894119c30dcf1b6745bab0002d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 2 Oct 2025 21:51:35 -0400 Subject: [PATCH 124/126] use truncate_space --- src/factorizations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index dcf64177b..0b94bcc61 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -91,7 +91,7 @@ for f! in (:eig_trunc!, :eigh_trunc!) (D, V)::Tuple{DiagonalTensorMap,AbstractTensorMap}, strategy::TruncationStrategy) ind = MAK.findtruncated(diagview(D), strategy) - V_truncated = spacetype(D)(c => length(I) for (c, I) in ind) + V_truncated = truncate_space(space(D, 1), ind) D̃ = DiagonalTensorMap{scalartype(D)}(undef, V_truncated) truncate_diagonal!(D̃, D, ind) From f8463bf2bdd77ea4971b25611c53d36330ef7d02 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 3 Oct 2025 08:05:24 -0400 Subject: [PATCH 125/126] Update src/factorizations/truncation.jl Co-authored-by: Jutho --- src/factorizations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 0b94bcc61..35ae0df44 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -53,7 +53,7 @@ function truncate_diagonal!(Ddst::DiagonalTensorMap, Dsrc::DiagonalTensorMap, in for (c, b) in blocks(Ddst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(diagview(b), @view(diagview(block(Dsrc, c))[I])) + copy!(diagview(b), view(diagview(block(Dsrc, c)), I)) end return Ddst end From 4a5db98e7320741516750254ffa2e43f94f2a96d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 3 Oct 2025 08:18:34 -0400 Subject: [PATCH 126/126] final simplifications --- src/factorizations/diagonal.jl | 4 ++-- src/factorizations/matrixalgebrakit.jl | 22 ---------------------- 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index 1cbaedafe..8fa33fd20 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -127,7 +127,7 @@ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!) end end -function MAK.check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm) +function MAK.check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) @@ -147,7 +147,7 @@ function MAK.check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::Diagon return nothing end -function MAK.check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, +function MAK.check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm) domain(t) == codomain(t) || throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) diff --git a/src/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl index 59f24fd7f..e67290603 100644 --- a/src/factorizations/matrixalgebrakit.jl +++ b/src/factorizations/matrixalgebrakit.jl @@ -212,28 +212,6 @@ function MAK.check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::Abstra return nothing end -function MAK.check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::AbstractAlgorithm) - domain(t) == codomain(t) || - throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) - - D, V = DV - - # type checks - @assert D isa DiagonalTensorMap - @assert V isa AbstractTensorMap - - # scalartype checks - @check_scalar D t - @check_scalar V t - - # space checks - V_D = fuse(domain(t)) - @check_space(D, V_D ← V_D) - @check_space(V, codomain(t) ← V_D) - - return nothing -end - function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::AbstractAlgorithm) @check_scalar D t real @assert D isa DiagonalTensorMap