diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 29b553d49..fc1a5119f 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -76,7 +76,7 @@ export leftorth, rightorth, leftnull, rightnull, isposdef, isposdef!, ishermitian, sylvester, rank, cond export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition, repartition! -export catdomain, catcodomain +export catdomain, catcodomain, absorb, absorb! export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQpos, SVD, SDD, Polar diff --git a/src/spaces/gradedspace.jl b/src/spaces/gradedspace.jl index 7aa746076..63903a1ac 100644 --- a/src/spaces/gradedspace.jl +++ b/src/spaces/gradedspace.jl @@ -168,23 +168,19 @@ function fuse(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector} end function infimum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector} - if V₁.dual == V₂.dual - typeof(V₁)(c => min(dim(V₁, c), dim(V₂, c)) - for c in - union(sectors(V₁), sectors(V₂)), dual in V₁.dual) - else + Visdual = isdual(V₁) + Visdual == isdual(V₂) || throw(SpaceMismatch("Infimum of space and dual space does not exist")) - end + return typeof(V₁)((Visdual ? dual(c) : c) => min(dim(V₁, c), dim(V₂, c)) + for c in intersect(sectors(V₁), sectors(V₂)); dual=Visdual) end function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector} - if V₁.dual == V₂.dual - typeof(V₁)(c => max(dim(V₁, c), dim(V₂, c)) - for c in - union(sectors(V₁), sectors(V₂)), dual in V₁.dual) - else + Visdual = isdual(V₁) + Visdual == isdual(V₂) || throw(SpaceMismatch("Supremum of space and dual space does not exist")) - end + return typeof(V₁)((Visdual ? dual(c) : c) => max(dim(V₁, c), dim(V₂, c)) + for c in union(sectors(V₁), sectors(V₂)); dual=Visdual) end function Base.show(io::IO, V::GradedSpace{I}) where {I<:Sector} diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index d6a06cedd..92188b9ef 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -125,6 +125,13 @@ function dim(W::HomSpace) return d end +""" + fusiontrees(W::HomSpace) + +Return the fusiontrees corresponding to all valid fusion channels of a given `HomSpace`. +""" +fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist + # Operations on HomSpaces # ----------------------- """ diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index f29bdf809..5eba8414c 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -512,6 +512,38 @@ function catcodomain(t1::TT, t2::TT) where {S,N₂,TT<:AbstractTensorMap{<:Any,S return t end +""" + absorb(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) + absorb!(tdst::AbstactTensorMap, tsrc::AbstractTensorMap) + +Absorb the contents of `tsrc` into `tdst`, which may have different sizes of data. +This is equivalent to the following operation on dense arrays, but also works for symmetric +tensors. Note also that this only overwrites the regions that are shared, and will do +nothing on the ones that are not, so it is up to the user to properly initialize the +destination. + +```julia +sub_axes = map((x, y) -> 1:min(x, y), size(tdst), size(tsrc)) +tdst[sub_axes...] .= tsrc[sub_axes...] +``` +""" +absorb(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) = absorb!(copy(tdst), tsrc) +function absorb!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) + numin(tdst) == numin(tsrc) && numout(tdst) == numout(tsrc) || + throw(DimensionError("Incompatible number of indices for source and destination")) + S = spacetype(tdst) + S == spacetype(tsrc) || throw(SpaceMismatch("incompatible spacetypes")) + dom = mapreduce(infimum, ⊗, domain(tdst), domain(tsrc); init=one(S)) + cod = mapreduce(infimum, ⊗, codomain(tdst), codomain(tsrc); init=one(S)) + for (f1, f2) in fusiontrees(cod ← dom) + @inbounds data_dst = tdst[f1, f2] + @inbounds data_src = tsrc[f1, f2] + sub_axes = map(Base.OneTo ∘ min, size(data_dst), size(data_src)) + data_dst[sub_axes...] .= data_src[sub_axes...] + end + return tdst +end + # tensor product of tensors """ ⊗(t1::AbstractTensorMap, t2::AbstractTensorMap, ...) -> TensorMap diff --git a/test/tensors.jl b/test/tensors.jl index 25bc157b9..30526f2c1 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -739,6 +739,27 @@ for V in spacelist @test t ≈ t′ end end + @timedtestset "Tensor absorpsion" begin + # absorbing small into large + t1 = zeros(V1 ⊕ V1, V2 ⊗ V3) + t2 = rand(V1, V2 ⊗ V3) + t3 = @constinferred absorb(t1, t2) + @test norm(t3) ≈ norm(t2) + @test norm(t1) == 0 + t4 = @constinferred absorb!(t1, t2) + @test t1 === t4 + @test t3 ≈ t4 + + # absorbing large into small + t1 = rand(V1 ⊕ V1, V2 ⊗ V3) + t2 = zeros(V1, V2 ⊗ V3) + t3 = @constinferred absorb(t2, t1) + @test norm(t3) < norm(t1) + @test norm(t2) == 0 + t4 = @constinferred absorb!(t2, t1) + @test t2 === t4 + @test t3 ≈ t4 + end end TensorKit.empty_globalcaches!() end