diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index 5782f093c..4fd32e349 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -82,6 +82,7 @@ spacetype(::Type{<:AbstractTensorMap{<:Any,S}}) where {S} sectortype(::Type{TT}) where {TT<:AbstractTensorMap} field(::Type{TT}) where {TT<:AbstractTensorMap} storagetype +blocktype ``` To obtain information about the indices, you can use: diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 6882d10db..ccbdd94ff 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -184,7 +184,7 @@ include("spaces/vectorspaces.jl") #------------------------------------- # general definitions include("tensors/abstracttensor.jl") -# include("tensors/tensortreeiterator.jl") +include("tensors/blockiterator.jl") include("tensors/tensor.jl") include("tensors/adjoint.jl") include("tensors/linalg.jl") diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index b67ba51bb..7c24d24f2 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -88,12 +88,11 @@ function blocksectors(W::HomSpace) N₁ = length(codom) N₂ = length(dom) I = sectortype(W) - if N₁ == 0 || N₂ == 0 - return (one(I),) - elseif N₂ <= N₁ - return sort!(filter!(c -> hasblock(codom, c), collect(blocksectors(dom)))) + # TODO: is sort! still necessary now that blocksectors of ProductSpace is sorted? + if N₂ <= N₁ + return sort!(filter!(c -> hasblock(codom, c), blocksectors(dom))) else - return sort!(filter!(c -> hasblock(dom, c), collect(blocksectors(codom)))) + return sort!(filter!(c -> hasblock(dom, c), blocksectors(codom))) end end @@ -349,3 +348,20 @@ function fusionblockstructure(W::HomSpace, ::GlobalLRUCache) end return structure end + +# Diagonal ranges +#---------------- +# TODO: is this something we want to cache? +function diagonalblockstructure(W::HomSpace) + ((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) || + throw(SpaceMismatch("Diagonal only support on V←V with a single space V")) + structure = SectorDict{sectortype(W),UnitRange{Int}}() # range + offset = 0 + dom = domain(W)[1] + for c in blocksectors(W) + d = dim(dom, c) + structure[c] = offset .+ (1:d) + offset += d + end + return structure +end diff --git a/src/spaces/productspace.jl b/src/spaces/productspace.jl index b07367919..578804f30 100644 --- a/src/spaces/productspace.jl +++ b/src/spaces/productspace.jl @@ -162,7 +162,7 @@ function blocksectors(P::ProductSpace{S,N}) where {S,N} end end end - return bs + return sort!(bs) end """ diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 6913cb1bb..6bea00363 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -195,6 +195,7 @@ sectortype(t::AbstractTensorMap) = sectortype(typeof(t)) InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t)) field(t::AbstractTensorMap) = field(typeof(t)) storagetype(t::AbstractTensorMap) = storagetype(typeof(t)) +blocktype(t::AbstractTensorMap) = blocktype(typeof(t)) similarstoragetype(t::AbstractTensorMap, T=scalartype(t)) = similarstoragetype(typeof(t), T) numout(t::AbstractTensorMap) = numout(typeof(t)) @@ -310,6 +311,15 @@ Return the matrix block of a tensor corresponding to a coupled sector `c`. See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasblock`](@ref). """ block +@doc """ + blocktype(t) + +Return the type of the matrix blocks of a tensor. +""" blocktype +function blocktype(::Type{T}) where {T<:AbstractTensorMap} + return Core.Compiler.return_type(block, Tuple{T,sectortype(T)}) +end + # Derived indexing behavior for tensors with trivial symmetry #------------------------------------------------------------- using TensorKit.Strided: SliceIndex diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index de7914192..b47c7f2d3 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -25,11 +25,21 @@ storagetype(::Type{AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT} #---------------------- block(t::AdjointTensorMap, s::Sector) = block(parent(t), s)' -function blocks(t::AdjointTensorMap) - iter = Base.Iterators.map(blocks(parent(t))) do (c, b) - return c => b' - end - return iter +blocks(t::AdjointTensorMap) = BlockIterator(t, blocks(parent(t))) + +function blocktype(::Type{AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT} + return Base.promote_op(adjoint, blocktype(TT)) +end + +function Base.iterate(iter::BlockIterator{<:AdjointTensorMap}, state...) + next = iterate(iter.structure, state...) + isnothing(next) && return next + (c, b), newstate = next + return c => adjoint(b), newstate +end + +function Base.getindex(iter::BlockIterator{<:AdjointTensorMap}, c::Sector) + return adjoint(Base.getindex(iter.structure, c)) end function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂}, diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl new file mode 100644 index 000000000..b4ec4b87b --- /dev/null +++ b/src/tensors/blockiterator.jl @@ -0,0 +1,15 @@ +""" + struct BlockIterator{T<:AbstractTensorMap,S} + +Iterator over the blocks of type `T`, possibly holding some pre-computed data of type `S` +""" +struct BlockIterator{T<:AbstractTensorMap,S} + t::T + structure::S +end + +Base.IteratorSize(::BlockIterator) = Base.HasLength() +Base.IteratorEltype(::BlockIterator) = Base.HasEltype() +Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T) +Base.length(iter::BlockIterator) = length(iter.structure) +Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index ebae85ccc..cce5624ea 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -110,7 +110,23 @@ function block(d::DiagonalTensorMap, s::Sector) return Diagonal(view(d.data, 1:0)) end -# TODO: is relying on generic AbstractTensorMap blocks sufficient? +blocks(t::DiagonalTensorMap) = BlockIterator(t, diagonalblockstructure(space(t))) +function blocktype(::Type{DiagonalTensorMap{T,S,A}}) where {T,S,A} + return Diagonal{T,SubArray{T,1,A,Tuple{UnitRange{Int}},true}} +end + +function Base.iterate(iter::BlockIterator{<:DiagonalTensorMap}, state...) + next = iterate(iter.structure, state...) + isnothing(next) && return next + (c, r), newstate = next + return c => Diagonal(view(iter.t.data, r)), newstate +end + +function Base.getindex(iter::BlockIterator{<:DiagonalTensorMap}, c::Sector) + sectortype(iter.t) === typeof(c) || throw(SectorMismatch()) + r = get(iter.structure, c, 1:0) + return Diagonal(view(iter.t.data, r)) +end # Indexing and getting and setting the data at the subblock level #----------------------------------------------------------------- diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 7f071f104..abdb6d8c7 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -421,28 +421,35 @@ end # Getting and setting the data at the block level #------------------------------------------------- -function block(t::TensorMap, s::Sector) - sectortype(t) == typeof(s) || throw(SectorMismatch()) - structure = fusionblockstructure(t).blockstructure - (d₁, d₂), r = get(structure, s) do +block(t::TensorMap, c::Sector) = blocks(t)[c] + +blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure) + +function blocktype(::Type{TT}) where {TT<:TensorMap} + A = storagetype(TT) + T = eltype(A) + return Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}} +end + +function Base.iterate(iter::BlockIterator{<:TensorMap}, state...) + next = iterate(iter.structure, state...) + isnothing(next) && return next + (c, (sz, r)), newstate = next + return c => reshape(view(iter.t.data, r), sz), newstate +end + +function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector) + sectortype(iter.t) === typeof(c) || throw(SectorMismatch()) + (d₁, d₂), r = get(iter.structure, c) do # is s is not a key, at least one of the two dimensions will be zero: # it then does not matter where exactly we construct a view in `t.data`, # as it will have length zero anyway - d₁′ = blockdim(codomain(t), s) - d₂′ = blockdim(domain(t), s) + d₁′ = blockdim(codomain(iter.t), c) + d₂′ = blockdim(domain(iter.t), c) l = d₁′ * d₂′ return (d₁′, d₂′), 1:l end - return reshape(view(t.data, r), (d₁, d₂)) -end - -function blocks(t::TensorMap) - structure = fusionblockstructure(t).blockstructure - iter = Base.Iterators.map(structure) do (c, ((d₁, d₂), r)) - b = reshape(view(t.data, r), (d₁, d₂)) - return c => b - end - return iter + return reshape(view(iter.t.data, r), (d₁, d₂)) end # Indexing and getting and setting the data at the subblock level diff --git a/test/diagonal.jl b/test/diagonal.jl index 944fecb25..83cdee10e 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -14,6 +14,15 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @test space(t) == (V ← V) @test space(t') == (V ← V) @test dim(t) == dim(space(t)) + # blocks + bs = @constinferred blocks(t) + (c, b1), state = @constinferred Nothing iterate(bs) + @test c == first(blocksectors(V ← V)) + next = @constinferred Nothing iterate(bs, state) + b2 = @constinferred block(t, first(blocksectors(t))) + @test b1 == b2 + @test eltype(bs) === typeof(b1) === TensorKit.blocktype(t) + # basic linear algebra @test isa(@constinferred(norm(t)), real(T)) @test norm(t)^2 ≈ dot(t, t) α = rand(T) diff --git a/test/tensors.jl b/test/tensors.jl index 67565d341..0441dd313 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -91,6 +91,14 @@ for V in spacelist @test space(t) == (W ← one(W)) @test domain(t) == one(W) @test typeof(t) == TensorMap{T,spacetype(t),5,0,Vector{T}} + # blocks + bs = @constinferred blocks(t) + (c, b1), state = @constinferred Nothing iterate(bs) + @test c == first(blocksectors(W)) + next = @constinferred Nothing iterate(bs, state) + b2 = @constinferred block(t, first(blocksectors(t))) + @test b1 == b2 + @test eltype(bs) === typeof(b1) === TensorKit.blocktype(t) end end @timedtestset "Tensor Dict conversion" begin @@ -143,6 +151,15 @@ for V in spacelist @test dim(t) == dim(space(t)) @test codomain(t) == codomain(W) @test domain(t) == domain(W) + # blocks for adjoint + bs = @constinferred blocks(t') + (c, b1), state = @constinferred Nothing iterate(bs) + @test c == first(blocksectors(W')) + next = @constinferred Nothing iterate(bs, state) + b2 = @constinferred block(t', first(blocksectors(t'))) + @test b1 == b2 + @test eltype(bs) === typeof(b1) === TensorKit.blocktype(t') + # linear algebra @test isa(@constinferred(norm(t)), real(T)) @test norm(t)^2 ≈ dot(t, t) α = rand(T)