Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/lib/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ spacetype(::Type{<:AbstractTensorMap{<:Any,S}}) where {S}
sectortype(::Type{TT}) where {TT<:AbstractTensorMap}
field(::Type{TT}) where {TT<:AbstractTensorMap}
storagetype
blocktype
```

To obtain information about the indices, you can use:
Expand Down
2 changes: 1 addition & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ include("spaces/vectorspaces.jl")
#-------------------------------------
# general definitions
include("tensors/abstracttensor.jl")
# include("tensors/tensortreeiterator.jl")
include("tensors/blockiterator.jl")
include("tensors/tensor.jl")
include("tensors/adjoint.jl")
include("tensors/linalg.jl")
Expand Down
26 changes: 21 additions & 5 deletions src/spaces/homspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,11 @@ function blocksectors(W::HomSpace)
N₁ = length(codom)
N₂ = length(dom)
I = sectortype(W)
if N₁ == 0 || N₂ == 0
return (one(I),)
elseif N₂ <= N₁
return sort!(filter!(c -> hasblock(codom, c), collect(blocksectors(dom))))
# TODO: is sort! still necessary now that blocksectors of ProductSpace is sorted?
if N₂ <= N₁
return sort!(filter!(c -> hasblock(codom, c), blocksectors(dom)))
else
return sort!(filter!(c -> hasblock(dom, c), collect(blocksectors(codom))))
return sort!(filter!(c -> hasblock(dom, c), blocksectors(codom)))
end
end

Expand Down Expand Up @@ -349,3 +348,20 @@ function fusionblockstructure(W::HomSpace, ::GlobalLRUCache)
end
return structure
end

# Diagonal ranges
#----------------
# TODO: is this something we want to cache?
function diagonalblockstructure(W::HomSpace)
((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) ||
throw(SpaceMismatch("Diagonal only support on V←V with a single space V"))
structure = SectorDict{sectortype(W),UnitRange{Int}}() # range
offset = 0
dom = domain(W)[1]
for c in blocksectors(W)
d = dim(dom, c)
structure[c] = offset .+ (1:d)
offset += d
end
return structure
end
2 changes: 1 addition & 1 deletion src/spaces/productspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ function blocksectors(P::ProductSpace{S,N}) where {S,N}
end
end
end
return bs
return sort!(bs)
end

"""
Expand Down
10 changes: 10 additions & 0 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
field(t::AbstractTensorMap) = field(typeof(t))
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
similarstoragetype(t::AbstractTensorMap, T=scalartype(t)) = similarstoragetype(typeof(t), T)

numout(t::AbstractTensorMap) = numout(typeof(t))
Expand Down Expand Up @@ -310,6 +311,15 @@
See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasblock`](@ref).
""" block

@doc """
blocktype(t)

Return the type of the matrix blocks of a tensor.
""" blocktype
function blocktype(::Type{T}) where {T<:AbstractTensorMap}
return Core.Compiler.return_type(block, Tuple{T,sectortype(T)})

Check warning on line 320 in src/tensors/abstracttensor.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/abstracttensor.jl#L319-L320

Added lines #L319 - L320 were not covered by tests
end

# Derived indexing behavior for tensors with trivial symmetry
#-------------------------------------------------------------
using TensorKit.Strided: SliceIndex
Expand Down
20 changes: 15 additions & 5 deletions src/tensors/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,21 @@
#----------------------
block(t::AdjointTensorMap, s::Sector) = block(parent(t), s)'

function blocks(t::AdjointTensorMap)
iter = Base.Iterators.map(blocks(parent(t))) do (c, b)
return c => b'
end
return iter
blocks(t::AdjointTensorMap) = BlockIterator(t, blocks(parent(t)))

function blocktype(::Type{AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT}
return Base.promote_op(adjoint, blocktype(TT))
end

function Base.iterate(iter::BlockIterator{<:AdjointTensorMap}, state...)
next = iterate(iter.structure, state...)
isnothing(next) && return next
(c, b), newstate = next
return c => adjoint(b), newstate
end

function Base.getindex(iter::BlockIterator{<:AdjointTensorMap}, c::Sector)
return adjoint(Base.getindex(iter.structure, c))

Check warning on line 42 in src/tensors/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/adjoint.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end

function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂},
Expand Down
15 changes: 15 additions & 0 deletions src/tensors/blockiterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
struct BlockIterator{T<:AbstractTensorMap,S}

Iterator over the blocks of type `T`, possibly holding some pre-computed data of type `S`
"""
struct BlockIterator{T<:AbstractTensorMap,S}
t::T
structure::S
end

Base.IteratorSize(::BlockIterator) = Base.HasLength()
Base.IteratorEltype(::BlockIterator) = Base.HasEltype()

Check warning on line 12 in src/tensors/blockiterator.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/blockiterator.jl#L12

Added line #L12 was not covered by tests
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
Base.length(iter::BlockIterator) = length(iter.structure)
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)
18 changes: 17 additions & 1 deletion src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,23 @@
return Diagonal(view(d.data, 1:0))
end

# TODO: is relying on generic AbstractTensorMap blocks sufficient?
blocks(t::DiagonalTensorMap) = BlockIterator(t, diagonalblockstructure(space(t)))
function blocktype(::Type{DiagonalTensorMap{T,S,A}}) where {T,S,A}
return Diagonal{T,SubArray{T,1,A,Tuple{UnitRange{Int}},true}}
end

function Base.iterate(iter::BlockIterator{<:DiagonalTensorMap}, state...)
next = iterate(iter.structure, state...)
isnothing(next) && return next
(c, r), newstate = next
return c => Diagonal(view(iter.t.data, r)), newstate
end

function Base.getindex(iter::BlockIterator{<:DiagonalTensorMap}, c::Sector)
sectortype(iter.t) === typeof(c) || throw(SectorMismatch())
r = get(iter.structure, c, 1:0)
return Diagonal(view(iter.t.data, r))

Check warning on line 128 in src/tensors/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/diagonal.jl#L125-L128

Added lines #L125 - L128 were not covered by tests
end

# Indexing and getting and setting the data at the subblock level
#-----------------------------------------------------------------
Expand Down
39 changes: 23 additions & 16 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,28 +421,35 @@ end

# Getting and setting the data at the block level
#-------------------------------------------------
function block(t::TensorMap, s::Sector)
sectortype(t) == typeof(s) || throw(SectorMismatch())
structure = fusionblockstructure(t).blockstructure
(d₁, d₂), r = get(structure, s) do
block(t::TensorMap, c::Sector) = blocks(t)[c]

blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)

function blocktype(::Type{TT}) where {TT<:TensorMap}
A = storagetype(TT)
T = eltype(A)
return Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}
end

function Base.iterate(iter::BlockIterator{<:TensorMap}, state...)
next = iterate(iter.structure, state...)
isnothing(next) && return next
(c, (sz, r)), newstate = next
return c => reshape(view(iter.t.data, r), sz), newstate
end

function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector)
sectortype(iter.t) === typeof(c) || throw(SectorMismatch())
(d₁, d₂), r = get(iter.structure, c) do
# is s is not a key, at least one of the two dimensions will be zero:
# it then does not matter where exactly we construct a view in `t.data`,
# as it will have length zero anyway
d₁′ = blockdim(codomain(t), s)
d₂′ = blockdim(domain(t), s)
d₁′ = blockdim(codomain(iter.t), c)
d₂′ = blockdim(domain(iter.t), c)
l = d₁′ * d₂′
return (d₁′, d₂′), 1:l
end
return reshape(view(t.data, r), (d₁, d₂))
end

function blocks(t::TensorMap)
structure = fusionblockstructure(t).blockstructure
iter = Base.Iterators.map(structure) do (c, ((d₁, d₂), r))
b = reshape(view(t.data, r), (d₁, d₂))
return c => b
end
return iter
return reshape(view(iter.t.data, r), (d₁, d₂))
end

# Indexing and getting and setting the data at the subblock level
Expand Down
9 changes: 9 additions & 0 deletions test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
@test space(t) == (V ← V)
@test space(t') == (V ← V)
@test dim(t) == dim(space(t))
# blocks
bs = @constinferred blocks(t)
(c, b1), state = @constinferred Nothing iterate(bs)
@test c == first(blocksectors(V ← V))
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t, first(blocksectors(t)))
@test b1 == b2
@test eltype(bs) === typeof(b1) === TensorKit.blocktype(t)
# basic linear algebra
@test isa(@constinferred(norm(t)), real(T))
@test norm(t)^2 ≈ dot(t, t)
α = rand(T)
Expand Down
17 changes: 17 additions & 0 deletions test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ for V in spacelist
@test space(t) == (W ← one(W))
@test domain(t) == one(W)
@test typeof(t) == TensorMap{T,spacetype(t),5,0,Vector{T}}
# blocks
bs = @constinferred blocks(t)
(c, b1), state = @constinferred Nothing iterate(bs)
@test c == first(blocksectors(W))
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t, first(blocksectors(t)))
@test b1 == b2
@test eltype(bs) === typeof(b1) === TensorKit.blocktype(t)
end
end
@timedtestset "Tensor Dict conversion" begin
Expand Down Expand Up @@ -143,6 +151,15 @@ for V in spacelist
@test dim(t) == dim(space(t))
@test codomain(t) == codomain(W)
@test domain(t) == domain(W)
# blocks for adjoint
bs = @constinferred blocks(t')
(c, b1), state = @constinferred Nothing iterate(bs)
@test c == first(blocksectors(W'))
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t', first(blocksectors(t')))
@test b1 == b2
@test eltype(bs) === typeof(b1) === TensorKit.blocktype(t')
# linear algebra
@test isa(@constinferred(norm(t)), real(T))
@test norm(t)^2 ≈ dot(t, t)
α = rand(T)
Expand Down
Loading